[Core][Bugfix]Refactor block manager for better testability (#3492)
This commit is contained in:
parent
8267b06c30
commit
14ccd94c89
0
tests/core/block/__init__.py
Normal file
0
tests/core/block/__init__.py
Normal file
56
tests/core/block/e2e/conftest.py
Normal file
56
tests/core/block/e2e/conftest.py
Normal file
@ -0,0 +1,56 @@
|
||||
import contextlib
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import torch
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
destroy_model_parallel)
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
|
||||
def cleanup():
    """Tear down the distributed / GPU state left behind by an ``LLM``.

    The steps run in a fixed order: destroy the model-parallel state, destroy
    the torch process group, run the garbage collector, empty the CUDA cache
    and finally shut Ray down.
    """
    destroy_model_parallel()

    # Best effort: an AssertionError raised while destroying the process
    # group is deliberately ignored (presumably raised when no group was
    # ever initialized -- confirm against torch.distributed).
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()

    # Collect garbage before emptying the CUDA cache so that memory held by
    # unreachable objects can be released as well.
    gc.collect()
    torch.cuda.empty_cache()

    ray.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                           baseline_llm_kwargs, seed):
    """Fixture returning the generator that yields the baseline ``LLM``.

    All arguments are forwarded unchanged to ``create_llm_generator``;
    ``baseline_llm_kwargs`` is passed as its ``distinct_llm_kwargs``.
    """
    generator = create_llm_generator(
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        seed,
    )
    return generator
|
||||
|
||||
|
||||
@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    """Fixture returning the generator that yields the ``LLM`` under test.

    All arguments are forwarded unchanged to ``create_llm_generator``;
    ``test_llm_kwargs`` is passed as its ``distinct_llm_kwargs``.
    """
    generator = create_llm_generator(
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        test_llm_kwargs,
        seed,
    )
    return generator
|
||||
|
||||
|
||||
def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                         distinct_llm_kwargs, seed):
    """Lazily construct an ``LLM``, yield it exactly once, then clean up.

    This is a generator function, so nothing (not even the ``LLM``
    constructor) runs until the caller starts iterating.

    Args:
        common_llm_kwargs: Base keyword arguments for ``LLM``.
        per_test_common_llm_kwargs: Per-test keyword arguments; override
            ``common_llm_kwargs`` on key collisions.
        distinct_llm_kwargs: Keyword arguments specific to this LLM; override
            both of the above on key collisions.
        seed: Value passed to ``set_random_seed`` after the LLM is built.

    Yields:
        The constructed ``LLM`` instance.
    """
    # Later mappings win, so the most specific kwargs take precedence.
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **distinct_llm_kwargs,
    }

    llm = LLM(**kwargs)
    set_random_seed(seed)

    try:
        yield llm
    finally:
        # Run the teardown even when the consumer stops early (the generator
        # is closed or an exception is thrown in at the yield). Otherwise a
        # failing test would skip cleanup() and leave its state behind.
        # Drop our own reference first so cleanup() can reclaim the LLM.
        del llm
        cleanup()
|
86
tests/core/block/e2e/test_correctness.py
Normal file
86
tests/core/block/e2e/test_correctness.py
Normal file
@ -0,0 +1,86 @@
|
||||
from itertools import cycle
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",

        # skip cuda graph creation for fast test.
        "enforce_eager": True,

        # Allow only 5 sequences of ~1024 tokens in worst case.
        "block_size": 16,
        "forced_num_gpu_blocks": 5 * (64 + 1),
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
                                               test_llm_generator, batch_size):
    """Check that block manager v2 matches block manager v1 under preemption.

    Two LLMs are built with the same, deliberately small GPU block budget,
    chosen so that sequences have to be preempted and evicted from the cache
    as the batch grows.

    Identical output token ids from both LLMs give confidence that the v2
    block manager does not corrupt the KV cache.

    NOTE: A large number of tokens is generated on purpose, so that an
    incorrect KV mapping has time to accumulate visible error.
    """
    output_len = 1024
    temperature = 0.0

    # The parametrized budget is 5 * (64 + 1) blocks of 16 tokens, i.e. room
    # for about five sequences of `output_len` tokens, while `batch_size` is
    # 10. Equality must hold even though not every sequence can stay
    # resident.
    base_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Repeat the base prompts round-robin until the batch is full.
    prompt_source = cycle(base_prompts)
    prompts = [next(prompt_source) for _ in range(batch_size)]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    print('Getting token ids from block manager v1')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids from block manager v2')
    test_token_ids = get_token_ids_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    # Compare per sequence first so a failure points at a single sequence.
    paired_token_ids = zip(baseline_token_ids, test_token_ids)
    for expected_token_ids, actual_token_ids in paired_token_ids:
        assert expected_token_ids == actual_token_ids

    # zip() stops at the shorter input; this also catches a length mismatch.
    assert baseline_token_ids == test_token_ids
|
||||
|
||||
|
||||
def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    """Run ``prompts`` through each LLM from ``llm_generator``.

    Args:
        llm_generator: Iterable yielding objects with a
            ``generate(prompts, sampling_params, use_tqdm=...)`` method.
        prompts: Prompts forwarded to ``generate``.
        sampling_params: Sampling parameters forwarded to ``generate``.

    Returns:
        One entry per output: the ``token_ids`` of that output's first
        completion. If the generator yields several LLMs, only the result of
        the last one is returned.

    Raises:
        ValueError: If ``llm_generator`` yields nothing.
    """
    token_ids = None

    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]

        # Drop our reference before the generator is resumed, so that this
        # frame is not what keeps the LLM alive during its teardown.
        del llm

    if token_ids is None:
        # Previously this fell through to an unbound local at `return`.
        raise ValueError("llm_generator did not yield any LLM")

    return token_ids
|
50
tests/core/block/test_block_space_manager.py
Normal file
50
tests/core/block/test_block_space_manager.py
Normal file
@ -0,0 +1,50 @@
|
||||
import pytest
|
||||
|
||||
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
|
||||
from vllm.core.interfaces import AllocStatus
|
||||
|
||||
from ..utils import create_seq_group
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80])
@pytest.mark.parametrize("num_seqs_per_group", [1, 4])
@pytest.mark.parametrize("watermark", [0.0, 0.5])
def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
                                num_gpu_blocks: int, watermark: float):
    """Check ``BlockSpaceManagerV2.can_allocate`` for growing prompt sizes.

    For every prompt size the expected ``AllocStatus`` is derived from the
    number of required blocks, the GPU block count and the watermark, then
    compared with what the block manager reports.
    """
    block_manager = BlockSpaceManagerV2(
        block_size=block_size,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        watermark=watermark,
    )
    num_watermark_blocks = int(watermark * num_gpu_blocks)

    num_output_blocks_per_seq = 1

    # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but
    # the current implementation assumes all seqs are new prompts / don't have
    # different output lens.
    num_output_blocks = num_output_blocks_per_seq

    for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks):
        output_len = block_size * num_output_blocks_per_seq
        seq_group = create_seq_group(
            seq_prompt_lens=block_size * num_prompt_blocks,
            seq_output_lens=[output_len for _ in range(num_seqs_per_group)],
        )

        assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks

        can_allocate_result = block_manager.can_allocate(seq_group)

        num_required_blocks = num_prompt_blocks + num_output_blocks
        num_blocks_left = num_gpu_blocks - num_required_blocks

        if num_blocks_left < num_watermark_blocks:
            expected_status = AllocStatus.NEVER
        elif num_gpu_blocks >= num_required_blocks:
            expected_status = AllocStatus.OK
        else:
            # Not reachable given the assert above; kept for completeness.
            expected_status = AllocStatus.LATER

        assert can_allocate_result == expected_status
|
500
tests/core/block/test_block_table.py
Normal file
500
tests/core/block/test_block_table.py
Normal file
@ -0,0 +1,500 @@
|
||||
import pytest
|
||||
|
||||
from vllm.core.block.block_table import BlockTable
|
||||
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
|
||||
from vllm.utils import Device, cdiv, chunk_list
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
def test_allocate_naive(block_size: int, sequence_len: int):
    """Allocate several block tables through the naive allocator.

    Five BlockTables are allocated for the same token ids on a
    CpuGpuBlockAllocator of type "naive". Before each allocation the number
    of free GPU blocks must have dropped by one full allocation per table
    created so far.
    """
    assert block_size > 1
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type="naive",
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        block_size=block_size,
    )

    token_ids = list(range(sequence_len))
    num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size)))

    # Keep every table referenced for the duration of the test.
    block_tables = []
    expected_free_blocks = num_gpu_blocks
    for _ in range(5):
        num_free = allocator.get_num_free_blocks(device=Device.GPU)
        assert num_free == expected_free_blocks

        table = BlockTable(
            block_size=block_size,
            block_allocator=allocator,
        )
        block_tables.append(table)
        table.allocate(token_ids=token_ids, device=Device.GPU)

        expected_free_blocks -= num_blocks_per_alloc
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
def test_allocate_prefix_caching(block_size: int, sequence_len: int):
    """Allocate several block tables through the prefix caching allocator.

    Five BlockTables are allocated for the same token ids on a
    CpuGpuBlockAllocator of type "prefix_caching". The tables are expected to
    share every full (immutable) block, and to each own a separate last block
    when that block is not full (mutable). After each allocation the number
    of free GPU blocks is checked against that expectation.
    """
    assert block_size > 1
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type="prefix_caching",
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        block_size=block_size,
    )

    token_ids = list(range(sequence_len))
    chunked_tokens = list(chunk_list(token_ids, block_size))

    # Only a partially filled last chunk counts as a mutable block.
    last_chunk_is_full = len(chunked_tokens[-1]) == block_size
    num_mutable_blocks_per_alloc = 0 if last_chunk_is_full else 1
    num_immutable_blocks_per_alloc = (len(chunked_tokens) -
                                      num_mutable_blocks_per_alloc)

    block_tables = []
    for alloc_i in range(1, 6):
        table = BlockTable(
            block_size=block_size,
            block_allocator=allocator,
        )
        block_tables.append(table)
        table.allocate(token_ids=token_ids, device=Device.GPU)

        # Expect all sequences to share allocations, except for their last
        # block (which may be mutable).
        num_used_blocks = (num_immutable_blocks_per_alloc +
                           num_mutable_blocks_per_alloc * alloc_i)
        num_free = allocator.get_num_free_blocks(device=Device.GPU)
        assert num_free == num_gpu_blocks - num_used_blocks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
@pytest.mark.parametrize("device", ["cpu", "gpu"])
def test_allocate_free(block_size: int, sequence_len: int, allocator_type: str,
                       device: str):
    """Allocate and free one BlockTable repeatedly on the given device.

    The same BlockTable is allocated and freed five times. After every
    allocation the free block count must drop by exactly one allocation and
    every physical block id must be set; after every free the count must be
    back at the full device capacity.
    """
    # Map the parametrized string ("cpu" / "gpu") onto the Device enum.
    device = Device[device.upper()]

    num_device_blocks = 1024
    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_device_blocks,
        num_cpu_blocks=num_device_blocks,
        block_size=block_size,
    )

    token_ids = list(range(sequence_len))
    num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size)))
    num_free_when_allocated = num_device_blocks - num_blocks_per_alloc

    block_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )

    for _ in range(5):
        block_table.allocate(token_ids=token_ids, device=device)
        assert allocator.get_num_free_blocks(device) == num_free_when_allocated
        assert all(block_id is not None
                   for block_id in block_table.physical_block_ids)

        block_table.free()
        assert allocator.get_num_free_blocks(device) == num_device_blocks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("append_len", [1, 16, 129])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_append_token_ids_allocation(block_size: int, sequence_len: int,
                                     append_len: int, allocator_type: str):
    """Check how many blocks a BlockTable holds before and after an append.

    A BlockTable is allocated for an initial sequence and then extended with
    additional token ids. The number of physical block ids is compared with
    the number of `block_size` chunks of the token list, both before and
    after the append.
    """
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        block_size=block_size,
    )

    token_ids = list(range(sequence_len))
    token_ids_to_append = list(range(append_len))

    block_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )

    # Expected block counts are the number of chunks of the token lists.
    all_token_ids = token_ids + token_ids_to_append
    num_expected_blocks_before_append = len(
        list(chunk_list(token_ids, block_size)))
    num_expected_appended_blocks = len(
        list(chunk_list(all_token_ids,
                        block_size))) - num_expected_blocks_before_append
    num_expected_blocks_after_append = (num_expected_blocks_before_append +
                                        num_expected_appended_blocks)

    block_table.allocate(token_ids=token_ids, device=Device.GPU)
    num_blocks_before = len(block_table.physical_block_ids)
    assert num_blocks_before == num_expected_blocks_before_append

    block_table.append_token_ids(token_ids_to_append)
    num_blocks_after = len(block_table.physical_block_ids)
    assert num_blocks_after == num_expected_blocks_after_append
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("num_empty_slots", [1, 16, 129])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int,
                                           num_empty_slots: int,
                                           allocator_type: str):
    """Check block usage when reserving empty slots in a BlockTable.

    A BlockTable is allocated for an initial sequence, then asked to ensure
    `num_empty_slots` empty slots. The number of physical block ids is
    checked before and after. Finally the reserved slots are filled with
    tokens, which must not consume any additional free block.
    """
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        block_size=block_size,
    )

    token_ids = list(range(sequence_len))

    block_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )

    # Model the empty slots as placeholder tokens to count expected chunks.
    token_ids_with_slots = token_ids + [-1] * num_empty_slots
    num_expected_blocks_before_append = len(
        list(chunk_list(token_ids, block_size)))
    num_expected_appended_blocks = len(
        list(chunk_list(token_ids_with_slots,
                        block_size))) - num_expected_blocks_before_append
    num_expected_blocks_after_ensure = (num_expected_blocks_before_append +
                                        num_expected_appended_blocks)

    block_table.allocate(token_ids=token_ids, device=Device.GPU)

    # Assert that the empty slots consume the expected number of additional
    # blocks.
    num_blocks_before = len(block_table.physical_block_ids)
    assert num_blocks_before == num_expected_blocks_before_append

    block_table.ensure_num_empty_slots(num_empty_slots)
    num_blocks_after = len(block_table.physical_block_ids)
    assert num_blocks_after == num_expected_blocks_after_ensure

    # Now, ensure no additional blocks consumed as we fill up the empty slots.
    num_free_blocks = allocator.get_num_free_blocks(device=Device.GPU)
    block_table.append_token_ids(token_ids=list(range(num_empty_slots)))
    assert num_free_blocks == allocator.get_num_free_blocks(device=Device.GPU)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("sequence_len", [1, 9])
@pytest.mark.parametrize("append_len", [1, 16, 129])
@pytest.mark.parametrize("append_size", [1, 4, 129])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_append_token_ids_correct_content(block_size: int, sequence_len: int,
                                          append_len: int, allocator_type: str,
                                          append_size: int):
    """Check that appended token ids end up in the table in order.

    `append_len` token ids are appended in chunks of `append_size`. After
    every chunk, and once more at the end, the full token list read back from
    the BlockTable must equal the initial tokens followed by everything
    appended so far.
    """
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        block_size=block_size,
    )

    token_ids = list(range(sequence_len))
    token_ids_to_append = list(range(append_len))

    block_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )
    block_table.allocate(token_ids=token_ids, device=Device.GPU)

    appended_so_far = []
    for append in chunk_list(token_ids_to_append, append_size):
        block_table.append_token_ids(append)
        appended_so_far.extend(append)

        expected_token_ids = token_ids + appended_so_far
        assert block_table._get_all_token_ids() == expected_token_ids

    final_token_ids = token_ids + token_ids_to_append
    assert block_table._get_all_token_ids() == final_token_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_len", [1, 9, 129])
@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_fork(seq_len: int, block_size: int, allocator_type: str):
    """Fork a block table and check how blocks are shared and released.

    1. After forking, the free block count is unchanged.
    2. The fork has the same physical block ids and token ids as the
        original.
    3. Freeing the original leaves the free block count unchanged and the
        fork's block ids intact.
    4. Freeing the fork as well brings the free block count back to the
        total number of GPU blocks.
    """
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=0,
        block_size=block_size,
    )

    token_ids = list(range(seq_len))

    block_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )
    block_table.allocate(token_ids)

    num_free_blocks_before_fork = allocator.get_num_free_blocks(
        device=Device.GPU)

    forked_block_table = block_table.fork()

    # Expect physical_block_ids and token_ids to match.
    original_ids = block_table.physical_block_ids
    assert original_ids == forked_block_table.physical_block_ids
    original_tokens = block_table._get_all_token_ids()
    assert original_tokens == forked_block_table._get_all_token_ids()

    # Do not expect any additional allocations.
    num_free_after_fork = allocator.get_num_free_blocks(device=Device.GPU)
    assert num_free_after_fork == num_free_blocks_before_fork

    # Free the original blocks. Assert num free blocks does not change, since
    # refcount is nonzero.
    block_table.free()
    num_free_after_first_free = allocator.get_num_free_blocks(
        device=Device.GPU)
    assert num_free_after_first_free == num_free_blocks_before_fork

    # Expect the forked block table to be unaffected by the free.
    assert all(block_id is not None
               for block_id in forked_block_table.physical_block_ids)

    # Free the forked blocks. Assert num free blocks does change, since
    # refcount is now zero.
    forked_block_table.free()
    assert allocator.get_num_free_blocks(device=Device.GPU) == num_gpu_blocks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("append_len", [1, 16, 129])
@pytest.mark.parametrize("appender", ["forked", "original"])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_cow(block_size: int, sequence_len: int, append_len: int,
             allocator_type: str, appender: str):
    """Fork a sequence, append to one of the two tables, expect a CoW.

    `appender` selects whether the fork or the original receives the new
    tokens; the other table must keep its physical block ids. A
    copy-on-write record is expected only when the last block of the initial
    sequence is partially filled.
    """
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=0,
        block_size=block_size,
    )

    token_ids = list(range(sequence_len))
    token_ids_to_append = list(range(append_len))

    original_block_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )

    num_full_blocks = sequence_len // block_size
    num_expected_non_cow_blocks = cdiv(sequence_len, block_size)
    num_expected_cow_blocks = cdiv(sequence_len + append_len,
                                   block_size) - num_full_blocks

    original_block_table.allocate(token_ids=token_ids, device=Device.GPU)
    original_block_ids = original_block_table.physical_block_ids

    forked_block_table = original_block_table.fork()

    # Expect no additional allocation (copy on _write_).
    num_free_after_fork = allocator.get_num_free_blocks(Device.GPU)
    assert num_free_after_fork == num_gpu_blocks - num_expected_non_cow_blocks

    if appender == "forked":
        appender_block_table = forked_block_table
        static_block_table = original_block_table
    elif appender == "original":
        appender_block_table = original_block_table
        static_block_table = forked_block_table
    else:
        raise ValueError(f"unknown test config {appender=}")

    # Write tokens.
    appender_block_table.append_token_ids(token_ids_to_append)

    # Expect the non-appending block table to have no change.
    assert static_block_table.physical_block_ids == original_block_ids
    assert appender_block_table.physical_block_ids != original_block_ids

    # Expect the blocks changed during append to have a CoW.
    num_used_blocks = num_expected_non_cow_blocks + num_expected_cow_blocks
    num_free_after_append = allocator.get_num_free_blocks(Device.GPU)
    assert num_free_after_append == num_gpu_blocks - num_used_blocks

    cows = allocator.clear_copy_on_writes()
    last_block_is_partial = sequence_len % block_size > 0
    if last_block_is_partial:
        # If the last block in the sequence is not full, then when appending
        # we expect a CoW.
        assert cows

        cow_block_id = num_full_blocks
        expected_src = static_block_table.physical_block_ids[cow_block_id]
        expected_dst = appender_block_table.physical_block_ids[cow_block_id]

        assert expected_src in cows
        assert expected_dst in cows[expected_src]
    else:
        # Otherwise, there should be no copy-on-write.
        assert not cows

    static_block_table.free()
    appender_block_table.free()

    # After free, expect all blocks to be freed.
    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("append_len", [1, 16, 129])
@pytest.mark.parametrize("lookahead_slots", [1, 16, 129])
@pytest.mark.parametrize("appender", ["forked", "original"])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_cow_lookahead_simple(block_size: int, sequence_len: int,
                              append_len: int, lookahead_slots: int,
                              allocator_type: str, appender: str):
    """Same scenario as test_cow, with lookahead slots reserved before fork.

    The assertions are less rigorous than in test_cow due to the complexity
    of the property under test: free block counts are only checked at the
    very end.
    """
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=0,
        block_size=block_size,
    )

    token_ids = list(range(sequence_len))
    token_ids_to_append = list(range(append_len))

    original_block_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )
    original_block_table.allocate(token_ids=token_ids, device=Device.GPU)

    # Allocate lookahead slots.
    original_block_table.ensure_num_empty_slots(lookahead_slots)
    original_block_ids = original_block_table.physical_block_ids

    forked_block_table = original_block_table.fork()

    if appender == "forked":
        appender_block_table = forked_block_table
        static_block_table = original_block_table
    elif appender == "original":
        appender_block_table = original_block_table
        static_block_table = forked_block_table
    else:
        raise ValueError(f"unknown test config {appender=}")

    # Write tokens.
    appender_block_table.append_token_ids(token_ids_to_append)

    # Expect the non-appending block table to have no change.
    assert static_block_table.physical_block_ids == original_block_ids
    assert appender_block_table.physical_block_ids != original_block_ids

    cows = allocator.clear_copy_on_writes()

    # Always expect copy-on-write
    assert cows

    if sequence_len % block_size > 0:
        # If the last block in the sequence is not full, the CoW is expected
        # to involve exactly that block.
        cow_block_id = sequence_len // block_size
        expected_src = static_block_table.physical_block_ids[cow_block_id]
        expected_dst = appender_block_table.physical_block_ids[cow_block_id]

        assert expected_src in cows
        assert expected_dst in cows[expected_src]

    static_block_table.free()
    appender_block_table.free()

    # After free, expect all blocks to be freed.
    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
|
42
tests/core/block/test_common.py
Normal file
42
tests/core/block/test_common.py
Normal file
@ -0,0 +1,42 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.core.block.common import RefCounter
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(20)))
@pytest.mark.parametrize("num_incrs", [1, 100])
@pytest.mark.parametrize("num_blocks", [1024])
def test_incr(seed: int, num_incrs: int, num_blocks: int):
    """Each ``RefCounter.incr`` call returns the running count, from 1."""
    random.seed(seed)

    counter = RefCounter(all_block_indices=list(range(num_blocks)))

    # Pick one block at random (seeded, hence reproducible).
    block_id = random.randint(0, num_blocks - 1)

    for expected_count in range(1, num_incrs + 1):
        assert counter.incr(block_id) == expected_count
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(20)))
@pytest.mark.parametrize("num_incrs", [1, 100])
@pytest.mark.parametrize("num_blocks", [1024])
def test_incr_decr(seed: int, num_incrs: int, num_blocks: int):
    """Increment a block's refcount, then decrement it back down to zero.

    Every ``incr`` / ``decr`` call must return the updated count, and one
    decrement more than the number of increments must raise
    ``AssertionError``.
    """
    random.seed(seed)

    counter = RefCounter(all_block_indices=list(range(num_blocks)))

    # Pick one block at random (seeded, hence reproducible).
    block_id = random.randint(0, num_blocks - 1)

    # Count up: 1, 2, ..., num_incrs.
    for expected_count in range(1, num_incrs + 1):
        assert counter.incr(block_id) == expected_count

    # Count down: num_incrs - 1, ..., 1, 0.
    for expected_count in reversed(range(num_incrs)):
        assert counter.decr(block_id) == expected_count

    # The count is zero now; one more decrement must be rejected.
    with pytest.raises(AssertionError):
        counter.decr(block_id)
|
93
tests/core/block/test_cpu_gpu_block_allocator.py
Normal file
93
tests/core/block/test_cpu_gpu_block_allocator.py
Normal file
@ -0,0 +1,93 @@
|
||||
import pytest
|
||||
|
||||
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
|
||||
from vllm.utils import Device, chunk_list
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_cpu_blocks", [0, 512])
@pytest.mark.parametrize("num_gpu_blocks", [1024])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_allocate_mutable(num_cpu_blocks: int, num_gpu_blocks: int,
                          block_size: int, allocator_type: str):
    """Exhaust and release mutable blocks on both devices.

    All CPU blocks are allocated, then all GPU blocks; afterwards the CPU
    blocks are freed, then the GPU blocks. After each phase the free block
    count of both devices is checked, so that allocating or freeing on one
    device is seen not to affect the other.
    """
    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=num_cpu_blocks,
        block_size=block_size,
    )

    assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks

    cpu_blocks = []
    for _ in range(num_cpu_blocks):
        cpu_blocks.append(
            allocator.allocate_mutable(prev_block=None, device=Device.CPU))
    assert allocator.get_num_free_blocks(Device.CPU) == 0
    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks

    gpu_blocks = []
    for _ in range(num_gpu_blocks):
        gpu_blocks.append(
            allocator.allocate_mutable(prev_block=None, device=Device.GPU))
    assert allocator.get_num_free_blocks(Device.CPU) == 0
    assert allocator.get_num_free_blocks(Device.GPU) == 0

    for block in cpu_blocks:
        allocator.free(block)
    assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
    assert allocator.get_num_free_blocks(Device.GPU) == 0

    for block in gpu_blocks:
        allocator.free(block)
    assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_cpu_blocks", [0, 512])
@pytest.mark.parametrize("num_gpu_blocks", [1024])
@pytest.mark.parametrize("block_size", [2])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_allocate_immutable(num_cpu_blocks: int, num_gpu_blocks: int,
                            block_size: int, allocator_type: str):
    """Exhaust and release immutable blocks on both devices.

    Every block is allocated with its own chunk of globally unique token
    ids. All CPU blocks are allocated, then all GPU blocks; afterwards the
    CPU blocks are freed, then the GPU blocks. After each phase the free
    block count of both devices is checked.
    """
    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=num_cpu_blocks,
        block_size=block_size,
    )

    # One distinct token id per slot; the first part feeds the GPU blocks,
    # the remainder the CPU blocks.
    num_gpu_tokens = num_gpu_blocks * block_size
    unique_token_ids = list(
        range((num_cpu_blocks + num_gpu_blocks) * block_size))
    gpu_token_ids = chunk_list(unique_token_ids[:num_gpu_tokens], block_size)
    cpu_token_ids = chunk_list(unique_token_ids[num_gpu_tokens:], block_size)

    assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks

    cpu_blocks = []
    for token_ids in cpu_token_ids:
        cpu_blocks.append(
            allocator.allocate_immutable(prev_block=None,
                                         token_ids=token_ids,
                                         device=Device.CPU))
    assert allocator.get_num_free_blocks(Device.CPU) == 0
    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks

    gpu_blocks = []
    for token_ids in gpu_token_ids:
        gpu_blocks.append(
            allocator.allocate_immutable(prev_block=None,
                                         token_ids=token_ids,
                                         device=Device.GPU))
    assert allocator.get_num_free_blocks(Device.CPU) == 0
    assert allocator.get_num_free_blocks(Device.GPU) == 0

    for block in cpu_blocks:
        allocator.free(block)
    assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
    assert allocator.get_num_free_blocks(Device.GPU) == 0

    for block in gpu_blocks:
        allocator.free(block)
    assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
|
102
tests/core/block/test_naive_block.py
Normal file
102
tests/core/block/test_naive_block.py
Normal file
@ -0,0 +1,102 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator
|
||||
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
|
||||
|
||||
|
||||
class TestNaiveBlockAllocator:
    """Tests for exhaustion, reuse-after-free and free-block accounting of
    NaiveBlockAllocator.
    """

    @staticmethod
    def create_allocate_lambda(allocate_type: str,
                               allocator: NaiveBlockAllocator,
                               prev_block: Optional[Block],
                               token_ids: List[int]):
        """Build a zero-argument callable that allocates a single block.

        Args:
            allocate_type: "immutable" to allocate through
                `allocate_immutable` (using prev_block and token_ids), or
                "mutable" to allocate through `allocate_mutable` (using
                prev_block only).
            allocator: The allocator the callable allocates from.
            prev_block: Passed through as the previous block.
            token_ids: Passed through for immutable allocations only.

        Returns:
            A callable taking no arguments that returns the new block.

        Raises:
            ValueError: If allocate_type is neither of the two known values.
        """
        if allocate_type == "immutable":

            def allocate_block():
                return allocator.allocate_immutable(prev_block=prev_block,
                                                    token_ids=token_ids)

            return allocate_block

        if allocate_type == "mutable":

            def allocate_block():
                return allocator.allocate_mutable(prev_block=prev_block)

            return allocate_block

        raise ValueError()

    @staticmethod
    @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"])
    @pytest.mark.parametrize("num_blocks", [1, 1024])
    @pytest.mark.parametrize("block_size", [1, 16])
    def test_allocate_ooms(allocate_type: str, num_blocks: int,
                           block_size: int):
        """Allocating one block more than the capacity must raise."""
        allocator = NaiveBlockAllocator(create_block=NaiveBlock,
                                        num_blocks=num_blocks,
                                        block_size=block_size)
        full_block_token_ids = list(range(block_size))
        allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
            allocate_type,
            allocator,
            prev_block=None,
            token_ids=full_block_token_ids)

        # Use up every block; the returned block objects are not needed.
        for _ in range(num_blocks):
            allocate_block()

        with pytest.raises(BlockAllocator.NoFreeBlocksError):
            allocate_block()

    @staticmethod
    @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"])
    @pytest.mark.parametrize("num_blocks", [1, 1024])
    @pytest.mark.parametrize("block_size", [1, 16])
    def test_free_prevents_oom(allocate_type: str, num_blocks: int,
                               block_size: int):
        """Freeing one block must make exactly one allocation possible, and
        that allocation must reuse the freed block id.
        """
        allocator = NaiveBlockAllocator(create_block=NaiveBlock,
                                        num_blocks=num_blocks,
                                        block_size=block_size)
        full_block_token_ids = list(range(block_size))
        allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
            allocate_type,
            allocator,
            prev_block=None,
            token_ids=full_block_token_ids)

        allocated = [allocate_block() for _ in range(num_blocks)]

        # The allocator is now full.
        with pytest.raises(BlockAllocator.NoFreeBlocksError):
            allocate_block()

        victim = allocated.pop()

        for _ in range(100):
            freed_id = victim.block_id
            allocator.free(victim)
            assert victim.block_id is None

            replacement = allocate_block()
            assert replacement.block_id == freed_id

            # Only the single freed slot was available.
            with pytest.raises(BlockAllocator.NoFreeBlocksError):
                allocate_block()

            victim = replacement

    @staticmethod
    @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"])
    @pytest.mark.parametrize("num_blocks", [1024])
    @pytest.mark.parametrize("block_size", [16])
    def test_get_num_free_blocks(allocate_type: str, num_blocks: int,
                                 block_size: int):
        """The free-block count must start at capacity, reach zero when
        everything is allocated, and grow by one per freed block.
        """
        allocator = NaiveBlockAllocator(create_block=NaiveBlock,
                                        num_blocks=num_blocks,
                                        block_size=block_size)
        full_block_token_ids = list(range(block_size))
        allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
            allocate_type,
            allocator,
            prev_block=None,
            token_ids=full_block_token_ids)

        assert allocator.get_num_free_blocks() == num_blocks

        allocated = [allocate_block() for _ in range(num_blocks)]

        # Before the k-th free (0-based) exactly k blocks have been returned.
        for num_freed_so_far, block in enumerate(allocated):
            assert allocator.get_num_free_blocks() == num_freed_so_far
            allocator.free(block)
|
384
tests/core/block/test_prefix_caching_block.py
Normal file
384
tests/core/block/test_prefix_caching_block.py
Normal file
@ -0,0 +1,384 @@
|
||||
import math
|
||||
import random
|
||||
from typing import List, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator
|
||||
from vllm.core.block.prefix_caching_block import (PrefixCachingBlock,
|
||||
PrefixCachingBlockAllocator)
|
||||
|
||||
|
||||
class TestPrefixCachingBlock:
    """Tests for the content hash exposed by PrefixCachingBlock."""

    @staticmethod
    @pytest.mark.parametrize("seed", list(range(10)))
    @pytest.mark.parametrize("block_size", [1, 16])
    @pytest.mark.parametrize("is_curr_block_full", [True, False])
    def test_first_block_has_correct_content_hash(seed: int, block_size: int,
                                                  is_curr_block_full: bool):
        """Check the content hash of a block that has no predecessor.

        A full block must hash to `hash_block_tokens(is_first_block=True,
        ...)`; a partially filled block must report no hash.
        """
        random.seed(seed)
        if is_curr_block_full:
            num_to_fill = block_size
        else:
            num_to_fill = random.randint(0, block_size - 1)
        token_ids = list(range(num_to_fill))
        mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator)

        first_block = PrefixCachingBlock(
            prev_block=None,
            token_ids=token_ids,
            block_size=block_size,
            prefix_caching_allocator=mock_allocator)

        if not is_curr_block_full:
            # A block that is not full is expected to have no hash.
            assert first_block.content_hash is None
            return

        # A full first block is expected to carry the first-block hash.
        actual_hash = first_block.content_hash
        expected_hash = PrefixCachingBlock.hash_block_tokens(
            is_first_block=True,
            prev_block_hash=None,
            cur_block_token_ids=token_ids)
        assert actual_hash == expected_hash

    @staticmethod
    @pytest.mark.parametrize("seed", list(range(10)))
    @pytest.mark.parametrize("block_size", [1, 16])
    @pytest.mark.parametrize("is_curr_block_full", [True, False])
    @pytest.mark.parametrize("prev_block_has_hash", [True, False])
    def test_nth_block_has_correct_content_hash(seed: int, block_size: int,
                                                is_curr_block_full: bool,
                                                prev_block_has_hash: bool):
        """Check the content hash of a block that has a predecessor.

        A hash is expected only when the block is full and the predecessor
        itself reports a hash; otherwise the hash must be None.
        """
        random.seed(seed)

        previous_block = MagicMock(spec=PrefixCachingBlock)
        prev_block_hash = random.randint(0, 1000)
        if prev_block_has_hash:
            previous_block.content_hash = prev_block_hash
        else:
            previous_block.content_hash = None

        if is_curr_block_full:
            num_to_fill = block_size
        else:
            num_to_fill = random.randint(0, block_size - 1)
        token_ids = list(range(num_to_fill))
        mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator)

        block_with_prev = PrefixCachingBlock(
            prev_block=previous_block,
            token_ids=token_ids,
            block_size=block_size,
            prefix_caching_allocator=mock_allocator,
        )

        if not (is_curr_block_full and prev_block_has_hash):
            # Either the block is not full or its predecessor has no hash,
            # so no hash is expected.
            assert block_with_prev.content_hash is None
            return

        # Full block chained onto a hashed predecessor: a hash is expected.
        actual_hash = block_with_prev.content_hash
        expected_hash = PrefixCachingBlock.hash_block_tokens(
            is_first_block=False,
            prev_block_hash=prev_block_hash,
            cur_block_token_ids=token_ids)
        assert actual_hash == expected_hash

    @staticmethod
    @pytest.mark.parametrize("block_size", [1, 2, 16])
    @pytest.mark.parametrize("num_tokens", list(range(3)))
    @pytest.mark.parametrize("num_empty_trailing_blocks", [0, 1, 10])
    def test_blocks_have_correct_hash_in_chain(block_size: int,
                                               num_tokens: int,
                                               num_empty_trailing_blocks: int):
        """Build two block chains from identical token ids and require that
        blocks at the same position have equal content hashes.
        """
        random.seed(0)

        token_ids = [random.randint(0, 50_000) for _ in range(num_tokens)]

        first_chain = TestPrefixCachingBlock.create_chain(
            block_size=block_size,
            token_ids=token_ids,
            num_empty_trailing_blocks=num_empty_trailing_blocks)
        second_chain = TestPrefixCachingBlock.create_chain(
            block_size=block_size,
            token_ids=token_ids,
            num_empty_trailing_blocks=num_empty_trailing_blocks)

        for block_a, block_b in zip(first_chain, second_chain):
            assert block_a.content_hash == block_b.content_hash

        # An empty chain is only legitimate when there were no tokens, and
        # then both chains must be empty.
        if not (first_chain and second_chain):
            assert first_chain == second_chain
            assert num_tokens == 0

    @staticmethod
    def create_chain(block_size: int,
                     token_ids: List[int],
                     num_empty_trailing_blocks=0) -> List[PrefixCachingBlock]:
        """Create a linked chain of blocks holding the given token ids.

        Each block is created empty with the previous block as its
        predecessor, then receives its slice of token_ids via
        `append_token_ids`. `num_empty_trailing_blocks` extra blocks with no
        tokens are added after the blocks needed for token_ids.

        Returns:
            The blocks in chain order; an empty list if no block is needed.
        """
        num_blocks = math.ceil(
            len(token_ids) / block_size) + num_empty_trailing_blocks

        if num_blocks == 0:
            return []

        allocator = MagicMock(spec=PrefixCachingBlockAllocator)

        chain = []
        tail = None
        for block_number in range(num_blocks):
            tail = PrefixCachingBlock(
                prev_block=tail,
                token_ids=[],
                block_size=block_size,
                prefix_caching_allocator=allocator,
            )

            start = block_number * block_size
            stop = (block_number + 1) * block_size
            tokens_to_append = token_ids[start:stop]
            # Trailing blocks get an empty slice and stay empty.
            if tokens_to_append:
                tail.append_token_ids(tokens_to_append)

            chain.append(tail)

        return chain
|
||||
|
||||
|
||||
class TestPrefixCachingBlockAllocator:
    """Tests for exhaustion, block sharing and free-block accounting of
    PrefixCachingBlockAllocator.
    """

    @staticmethod
    def create_allocate_lambda(allocate_type: str, allocator: BlockAllocator,
                               prev_block: Optional[Block],
                               token_ids: List[int]):
        """Build a zero-argument callable that allocates a single block.

        Args:
            allocate_type: "immutable" to allocate through
                `allocate_immutable` (using prev_block and token_ids), or
                "mutable" to allocate through `allocate_mutable` (using
                prev_block only).
            allocator: The allocator the callable allocates from.
            prev_block: Passed through as the previous block.
            token_ids: Passed through for immutable allocations only.

        Returns:
            A callable taking no arguments that returns the new block.

        Raises:
            ValueError: If allocate_type is neither of the two known values.
        """
        if allocate_type == "immutable":

            def allocate_block():
                return allocator.allocate_immutable(prev_block=prev_block,
                                                    token_ids=token_ids)

            return allocate_block

        if allocate_type == "mutable":

            def allocate_block():
                return allocator.allocate_mutable(prev_block=prev_block)

            return allocate_block

        raise ValueError()

    @staticmethod
    @pytest.mark.parametrize("num_blocks", [1, 1024])
    @pytest.mark.parametrize("block_size", [1, 16])
    def test_allocate_mutable_ooms(num_blocks: int, block_size: int):
        """Allocating one mutable block beyond capacity must raise."""
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)
        allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda(
            allocate_type="mutable",
            allocator=allocator,
            prev_block=None,
            token_ids=list(range(block_size)),
        )

        # Use up every block; the returned block objects are not needed.
        for _ in range(num_blocks):
            allocate_block()

        with pytest.raises(BlockAllocator.NoFreeBlocksError):
            allocate_block()

    @staticmethod
    @pytest.mark.parametrize("num_blocks", [1, 1024])
    @pytest.mark.parametrize("block_size", [1, 16])
    def test_allocate_immutable_does_not_oom_single_hash(
            num_blocks: int, block_size: int):
        """Repeated immutable allocations with identical content must keep
        succeeding and must all report the same block id.
        """
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)
        allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda(
            allocate_type="immutable",
            allocator=allocator,
            prev_block=None,
            token_ids=list(range(block_size)),
        )

        same_content_blocks = [allocate_block() for _ in range(num_blocks)]

        # One more than the capacity: mutable blocks would run out here,
        # identical immutable blocks are expected not to.
        non_oom_block = allocate_block()

        # Every allocation is expected to map to one physical block index.
        for block in same_content_blocks:
            assert block.block_id == non_oom_block.block_id

    @staticmethod
    @pytest.mark.parametrize("num_blocks", [1, 1024])
    @pytest.mark.parametrize("block_size", [1, 16])
    def test_allocate_immutable_ooms_many_hash(num_blocks: int,
                                               block_size: int):
        """Fill the allocator with a long chain of distinct block contents.

        Afterwards, allocating unseen content (immutable or mutable) must
        raise, while re-allocating the very same chain must succeed and
        reuse the same block ids.
        """
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)

        # Enough distinct token ids to fill every block exactly once.
        token_ids = list(range(num_blocks * block_size))

        chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )
        chain_tail = chain[-1]

        # Content not seen at this position needs a new block: must fail.
        with pytest.raises(BlockAllocator.NoFreeBlocksError):
            allocator.allocate_immutable(prev_block=chain_tail,
                                         token_ids=list(range(block_size)))

        # A mutable block always needs a new block: must fail too.
        with pytest.raises(BlockAllocator.NoFreeBlocksError):
            allocator.allocate_mutable(prev_block=chain_tail)

        # The identical chain is expected to be allocatable again.
        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )

        # Both chains must be non-empty and agree on every block id.
        assert chain and second_chain
        for original_block, repeated_block in zip(chain, second_chain):
            assert original_block.block_id == repeated_block.block_id

    @staticmethod
    @pytest.mark.parametrize("num_blocks", [1, 1024])
    @pytest.mark.parametrize("block_size", [1, 16])
    def test_free_prevents_oom(num_blocks: int, block_size: int):
        """Freeing one block of a full allocator must allow exactly one
        mutable allocation, which must reuse the freed block id.
        """
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)

        # Enough distinct token ids to fill every block exactly once.
        token_ids = list(range(num_blocks * block_size))

        chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )

        # The allocator is full, so a mutable allocation must fail.
        with pytest.raises(BlockAllocator.NoFreeBlocksError):
            allocator.allocate_mutable(prev_block=None)

        victim = chain[-1]

        # Repeat the free/allocate cycle many times; the iteration number
        # is attached to each assertion to ease debugging.
        for iteration in range(100):
            freed_id = victim.block_id
            allocator.free(victim)
            assert victim.block_id is None, iteration

            replacement = allocator.allocate_mutable(prev_block=None)
            assert replacement.block_id == freed_id, iteration

            with pytest.raises(BlockAllocator.NoFreeBlocksError):
                allocator.allocate_mutable(prev_block=None)

            victim = replacement

    @staticmethod
    @pytest.mark.parametrize("num_blocks", [1024])
    @pytest.mark.parametrize("block_size", [16])
    @pytest.mark.parametrize("seed", list(range(20)))
    def test_get_num_free_blocks(num_blocks: int, block_size: int, seed: int):
        """The free-block count must grow by one for each freed block of a
        chain that occupies a random part of the allocator.
        """
        random.seed(seed)
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)
        num_blocks_to_consume = random.randint(1, num_blocks - 1)

        # Token ids filling num_blocks_to_consume blocks, which is always
        # fewer than num_blocks.
        token_ids = list(range(num_blocks_to_consume * block_size))

        chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )

        free_after_allocation = num_blocks - num_blocks_to_consume

        # Before the k-th free (0-based) exactly k chain blocks are back.
        for num_freed_so_far, block in enumerate(chain):
            expected_free = free_after_allocation + num_freed_so_far
            assert allocator.get_num_free_blocks() == expected_free
            allocator.free(block)

    @staticmethod
    @pytest.mark.parametrize("num_blocks", [1024])
    @pytest.mark.parametrize("block_size", [16])
    @pytest.mark.parametrize("seed", list(range(20)))
    def test_get_num_free_blocks_shared(num_blocks: int, block_size: int,
                                        seed: int):
        """Allocate the same chain twice and free both copies block by
        block, checking that the free count only grows on the second pass.
        """
        random.seed(seed)
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)
        num_blocks_to_consume = random.randint(1, num_blocks - 1)

        # Token ids filling num_blocks_to_consume blocks, which is always
        # fewer than num_blocks.
        token_ids = list(range(num_blocks_to_consume * block_size))

        first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )
        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )

        free_after_allocation = num_blocks - num_blocks_to_consume

        # First pass: every block is still held by the second chain, so the
        # free count is expected to stay constant.
        for block in first_chain:
            assert allocator.get_num_free_blocks() == free_after_allocation
            allocator.free(block)

        # Second pass: nothing else holds the blocks any more, so each free
        # is expected to add one to the free count.
        for num_freed_so_far, block in enumerate(second_chain):
            expected_free = free_after_allocation + num_freed_so_far
            assert allocator.get_num_free_blocks() == expected_free
            allocator.free(block)

    @staticmethod
    def create_immutable_chain(
        block_size: int,
        token_ids: List[int],
        allocator: PrefixCachingBlockAllocator,
    ) -> List[PrefixCachingBlock]:
        """Allocate a chain of immutable blocks holding the given token ids.

        token_ids is cut into consecutive slices of block_size tokens; each
        slice is allocated with `allocate_immutable`, using the previously
        allocated block as its predecessor.

        Returns:
            The blocks in chain order; an empty list if token_ids is empty.
        """
        num_blocks = math.ceil(len(token_ids) / block_size)

        if num_blocks == 0:
            return []

        chain = []
        tail = None
        for block_number in range(num_blocks):
            start = block_number * block_size
            stop = (block_number + 1) * block_size
            tail = allocator.allocate_immutable(
                prev_block=tail, token_ids=token_ids[start:stop])
            chain.append(tail)

        return chain
|
@ -5,8 +5,9 @@ import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.block import PhysicalTokenBlock
|
||||
from vllm.core.block_manager import (AllocStatus, BlockSpaceManager,
|
||||
UncachedBlockAllocator)
|
||||
from vllm.core.block_manager_v1 import (BlockSpaceManagerV1,
|
||||
UncachedBlockAllocator)
|
||||
from vllm.core.interfaces import AllocStatus
|
||||
from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.utils import Device
|
||||
|
||||
@ -63,10 +64,10 @@ def test_allocate():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManager(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
# Allocate same sequence group to all available gpu blocks.
|
||||
for i in range(num_gpu_blocks):
|
||||
@ -77,10 +78,10 @@ def test_allocate():
|
||||
|
||||
# Allocate same sequence group to all available gpu blocks.
|
||||
# Use watermark to reserve one gpu block.
|
||||
block_manager = BlockSpaceManager(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=1 / num_gpu_blocks)
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=1 / num_gpu_blocks)
|
||||
for i in range(num_gpu_blocks - 1):
|
||||
_, seq_group = create_dummy_prompt(str(i), block_size)
|
||||
assert block_manager.can_allocate(seq_group)
|
||||
@ -92,10 +93,10 @@ def test_append_slot_single_seq():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManager(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
# Allocate single seq to gpu block.
|
||||
prompt, seq_group = create_dummy_prompt("1", block_size)
|
||||
@ -124,10 +125,10 @@ def test_append_slot_cow():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManager(block_size=block_size,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
watermark=0)
|
||||
block_manager = BlockSpaceManagerV1(block_size=block_size,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
# Allocate prompt to gpu block. There is one slot left in the block.
|
||||
prompt = Sequence(seq_id=1,
|
||||
@ -165,10 +166,10 @@ def test_fork():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManager(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
prompt, seq_group = create_dummy_prompt("1",
|
||||
block_size - 1,
|
||||
@ -192,10 +193,10 @@ def test_swap():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManager(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1)
|
||||
prompt.status = SequenceStatus.WAITING
|
||||
@ -238,10 +239,10 @@ def test_free():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManager(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
prompt, seq_group = create_dummy_prompt("1", block_size)
|
||||
block_manager.allocate(seq_group)
|
||||
@ -262,10 +263,10 @@ def test_reset():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManager(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
# Allocate same seq group on all available gpu blocks.
|
||||
original_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
@ -289,11 +290,11 @@ def test_sliding_window_multi_seq():
|
||||
num_cpu_blocks = 8
|
||||
num_gpu_blocks = 8
|
||||
sliding_window = 2
|
||||
block_manager = BlockSpaceManager(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
sliding_window=sliding_window,
|
||||
watermark=0)
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
sliding_window=sliding_window,
|
||||
watermark=0)
|
||||
|
||||
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
|
||||
|
||||
|
@ -13,7 +13,7 @@ from .utils import create_dummy_prompt
|
||||
def test_scheduler_add_seq_group():
|
||||
block_size = 4
|
||||
scheduler_config = SchedulerConfig(100, 64, 1)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto")
|
||||
cache_config.num_cpu_blocks = 4
|
||||
cache_config.num_gpu_blocks = 4
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
|
@ -2,7 +2,7 @@ import time
|
||||
from typing import Tuple
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.sequence import Sequence, SequenceGroup
|
||||
from vllm.sequence import Logprob, Sequence, SequenceGroup
|
||||
|
||||
|
||||
def create_dummy_prompt(
|
||||
@ -23,5 +23,42 @@ def create_dummy_prompt(
|
||||
return prompt, seq_group
|
||||
|
||||
|
||||
def create_seq_group(
    seq_prompt_lens=1024,
    seq_output_lens=(128, ),
    request_id='0',
    seq_id_start=0,
) -> SequenceGroup:
    """Create a SequenceGroup of sequences that share one dummy prompt.

    One Sequence is created per entry of seq_output_lens. Every sequence
    uses a prompt of `seq_prompt_lens` zero tokens, an empty prompt string
    and a block size of 16, and then receives `output_len` appended tokens
    with ids 0..output_len-1, each carrying `Logprob(0.0)`.

    Args:
        seq_prompt_lens: Number of prompt tokens of every sequence.
        seq_output_lens: Number of generated tokens per sequence; must not
            be empty.
        request_id: Request id given to the group.
        seq_id_start: Id of the first sequence; later sequences count up
            from it.

    Returns:
        The group, created with default SamplingParams and the current time
        as its arrival time.
    """
    assert len(seq_output_lens) > 0

    # The same list object is handed to every sequence.
    prompt_token_ids = [0] * seq_prompt_lens

    def build_sequence(seq_id, output_len):
        sequence = Sequence(
            seq_id=seq_id,
            prompt="",
            prompt_token_ids=prompt_token_ids,
            block_size=16,
        )
        for token_id in range(output_len):
            sequence.append_token_id(
                token_id=token_id,
                logprobs={token_id: Logprob(0.0)},
            )
        return sequence

    seqs = [
        build_sequence(seq_id_start + seq_id_offset, output_len)
        for seq_id_offset, output_len in enumerate(seq_output_lens)
    ]

    return SequenceGroup(
        request_id=request_id,
        seqs=seqs,
        sampling_params=SamplingParams(),
        arrival_time=time.time(),
    )
|
||||
|
||||
|
||||
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    """Return how many blocks of `block_size` are needed for `seq_len`.

    Despite the name, the result is a block count (ceiling division), not a
    token count rounded up to a multiple of the block size.
    """
    num_blocks, _ = divmod(seq_len + block_size - 1, block_size)
    return num_blocks
|
||||
|
@ -4,7 +4,7 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from vllm.core.block_manager import CachedBlockAllocator
|
||||
from vllm.core.block_manager_v1 import CachedBlockAllocator
|
||||
from vllm.utils import Device
|
||||
|
||||
|
||||
|
@ -324,6 +324,8 @@ class CacheConfig:
|
||||
vLLM execution.
|
||||
swap_space: Size of the CPU swap space per GPU (in GiB).
|
||||
cache_dtype: Data type for kv cache storage.
|
||||
forced_num_gpu_blocks: Number of GPU blocks to use. This overrides the
|
||||
profiled num_gpu_blocks if specified. Does nothing if None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -332,12 +334,14 @@ class CacheConfig:
|
||||
gpu_memory_utilization: float,
|
||||
swap_space: int,
|
||||
cache_dtype: str,
|
||||
forced_num_gpu_blocks: Optional[int] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
enable_prefix_caching: bool = False,
|
||||
) -> None:
|
||||
self.block_size = block_size
|
||||
self.gpu_memory_utilization = gpu_memory_utilization
|
||||
self.swap_space_bytes = swap_space * _GB
|
||||
self.forced_num_gpu_blocks = forced_num_gpu_blocks
|
||||
self.cache_dtype = cache_dtype
|
||||
self.sliding_window = sliding_window
|
||||
self.enable_prefix_caching = enable_prefix_caching
|
||||
@ -528,6 +532,7 @@ class SchedulerConfig:
|
||||
and generated text).
|
||||
delay_factor: Apply a delay (of delay factor multiplied by previous
|
||||
prompt latency) before scheduling next prompt.
|
||||
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -535,6 +540,7 @@ class SchedulerConfig:
|
||||
max_num_batched_tokens: Optional[int],
|
||||
max_num_seqs: int,
|
||||
max_model_len: int,
|
||||
use_v2_block_manager: bool = False,
|
||||
delay_factor: float = 0.0,
|
||||
) -> None:
|
||||
if max_num_batched_tokens is not None:
|
||||
@ -546,6 +552,7 @@ class SchedulerConfig:
|
||||
self.max_num_seqs = max_num_seqs
|
||||
self.max_model_len = max_model_len
|
||||
self.delay_factor = delay_factor
|
||||
self.use_v2_block_manager = use_v2_block_manager
|
||||
self._verify_args()
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
|
245
vllm/core/block/block_table.py
Normal file
245
vllm/core/block/block_table.py
Normal file
@ -0,0 +1,245 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator
|
||||
from vllm.utils import Device, cdiv, chunk_list
|
||||
|
||||
|
||||
class BlockTable:
|
||||
"""A class to manage blocks for a specific sequence.
|
||||
|
||||
The BlockTable maps a sequence of tokens to a list of blocks, where each
|
||||
block represents a contiguous memory allocation for a portion of the
|
||||
sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is
|
||||
responsible for allocating and freeing memory for the blocks.
|
||||
|
||||
Args:
|
||||
block_size (int): The maximum number of tokens that can be stored in a
|
||||
single block.
|
||||
block_allocator (DeviceAwareBlockAllocator): The block allocator used to
|
||||
manage memory for the blocks.
|
||||
_blocks (Optional[List[Block]], optional): An optional list of existing
|
||||
blocks to initialize the BlockTable with. If not provided, an empty
|
||||
BlockTable is created.
|
||||
|
||||
Attributes:
|
||||
_block_size (int): The maximum number of tokens that can be stored in a
|
||||
single block.
|
||||
_allocator (DeviceAwareBlockAllocator): The block allocator used to
|
||||
manage memory for the blocks.
|
||||
_blocks (Optional[List[Block]]): The list of blocks managed by this
|
||||
BlockTable.
|
||||
_num_full_slots (int): The number of tokens currently stored in the
|
||||
blocks.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    block_size: int,
    block_allocator: DeviceAwareBlockAllocator,
    _blocks: Optional[List[Block]] = None,
):
    """Store the block size, the allocator and any pre-existing blocks.

    Args:
        block_size: Maximum number of tokens stored in one block.
        block_allocator: Allocator used to obtain further blocks.
        _blocks: Existing blocks to start from, or None for a table that
            has nothing allocated yet.
    """
    self._allocator = block_allocator
    self._block_size = block_size
    self._blocks: Optional[List[Block]] = _blocks

    # _blocks may still be None at this point, so the number of stored
    # tokens is obtained through the helper instead of being computed from
    # the blocks directly.
    stored_token_ids = self._get_all_token_ids()
    self._num_full_slots = len(stored_token_ids)
|
||||
|
||||
@staticmethod
def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
    """Return the number of blocks needed to hold `token_ids`.

    The count is the ceiling of `len(token_ids) / block_size`, i.e. it
    assumes each block has to be newly allocated and nothing is saved by
    sharing (for example through prefix caching).

    Args:
        token_ids (List[int]): The token IDs to be stored.
        block_size (int): The maximum number of tokens per block.

    Returns:
        int: The number of blocks required for the given token IDs.
    """
    num_tokens = len(token_ids)
    return cdiv(num_tokens, block_size)
|
||||
|
||||
def allocate(self,
             token_ids: List[int],
             device: Device = Device.GPU) -> None:
    """Allocate the blocks that hold the initial token IDs of this table.

    May only be called on a table that has not been allocated yet and with
    a non-empty `token_ids`; both conditions are asserted.

    Args:
        token_ids (List[int]): The token IDs to be stored.
        device (Device, optional): The device to allocate the blocks on.
            Defaults to Device.GPU.
    """
    assert not self._is_allocated
    assert token_ids

    # The first block of a fresh table has no predecessor.
    new_blocks = self._allocate_blocks_for_token_ids(prev_block=None,
                                                     token_ids=token_ids,
                                                     device=device)
    self._blocks = new_blocks
    self._num_full_slots = len(token_ids)
|
||||
|
||||
def append_token_ids(self, token_ids: List[int]) -> None:
    """Append token IDs to the blocks of an already allocated table.

    Room for the new tokens is first guaranteed through
    `ensure_num_empty_slots`. The tokens are then split into a first chunk
    that fills what is left of the first block that is not yet full,
    followed by chunks of `block_size` tokens, and each chunk is appended
    to the corresponding block.

    Args:
        token_ids (List[int]): The token IDs to be appended.
    """
    assert self._is_allocated

    self.ensure_num_empty_slots(num_empty_slots=len(token_ids))

    block_size = self._block_size
    num_full_slots = self._num_full_slots

    # Free room in the first block that can take tokens; a whole block when
    # the stored tokens end exactly on a block boundary.
    first_chunk_size = block_size - (num_full_slots % block_size)
    head_chunk = token_ids[:first_chunk_size]
    remaining_token_ids = token_ids[first_chunk_size:]
    token_chunks = [head_chunk] + chunk_list(remaining_token_ids, block_size)

    # Blocks before this index are already completely filled.
    first_open_block = num_full_slots // block_size
    open_blocks = self._blocks[first_open_block:]

    for block, chunk in zip(open_blocks, token_chunks):
        block.append_token_ids(chunk)

    self._num_full_slots += len(token_ids)
|
||||
|
||||
def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
    """Grow the BlockTable until it has at least `num_empty_slots` free.

    Nothing happens when the table already has enough empty slots.
    Otherwise the shortfall is rounded up to whole blocks (via `cdiv`) and
    that many mutable blocks are appended, each chained to the current
    last block.

    Args:
        num_empty_slots (int): The minimum number of empty slots required.
    """
    assert self._is_allocated

    shortfall = num_empty_slots - self._num_empty_slots
    if shortfall <= 0:
        return

    for _ in range(cdiv(shortfall, self._block_size)):
        tail_block = self._blocks[-1]
        # Currently the block table only supports appending tokens to GPU
        # blocks.
        new_block = self._allocator.allocate_mutable(prev_block=tail_block,
                                                     device=Device.GPU)
        self._blocks.append(new_block)
|
||||
|
||||
def fork(self) -> "BlockTable":
    """Return a new BlockTable built from a fork of this table's blocks.

    The allocator is asked to fork the chain ending at this table's last
    block; the resulting block list is wrapped in a new BlockTable that
    uses the same block size and the same allocator as this one.

    Returns:
        BlockTable: A new BlockTable holding the forked blocks.
    """
    assert self._is_allocated

    tail_block = self._blocks[-1]
    shared_blocks = self._allocator.fork(tail_block)

    forked_table = BlockTable(
        block_size=self._block_size,
        block_allocator=self._allocator,
        _blocks=shared_blocks,
    )
    return forked_table
|
||||
|
||||
def free(self) -> None:
    """Release every block held by this BlockTable.

    Each block in `_blocks` is handed to the allocator's `free` method, in
    order. Afterwards `_blocks` is reset to `None`, which marks the table
    as unallocated.
    """
    assert self._is_allocated

    release = self._allocator.free
    for held_block in self._blocks:
        release(held_block)

    self._blocks = None
|
||||
|
||||
@property
def physical_block_ids(self) -> List[int]:
    """Physical block indices of the blocks in this BlockTable.

    The i-th entry is the `block_id` of the i-th block in `_blocks`.

    Returns:
        List[int]: One block id per block, in table order.
    """
    assert self._is_allocated

    ids = [held_block.block_id for held_block in self._blocks]
    return ids
|
||||
|
||||
def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
                                   token_ids: List[int],
                                   device: Device) -> List[Block]:
    """Allocate a chain of blocks holding `token_ids` on `device`.

    The tokens are split with `chunk_list` using this table's block size.
    A chunk that fills a whole block is stored in an immutable block; a
    shorter chunk goes into a freshly allocated mutable block. Each new
    block is chained to the one allocated before it, starting from
    `prev_block`.

    Args:
        prev_block (Optional[Block]): Block the first new block links to.
        token_ids (List[int]): Token IDs to store.
        device (Device): Device to allocate the blocks on.

    Returns:
        List[Block]: The newly allocated blocks, in sequence order.
    """
    allocated = []
    tail = prev_block

    for chunk in chunk_list(token_ids, self._block_size):
        if len(chunk) != self._block_size:
            # Partial chunk: fill a mutable block with the token ids.
            tail = self._allocator.allocate_mutable(prev_block=tail,
                                                    device=device)
            tail.append_token_ids(chunk)
        else:
            # Full chunk: create an immutable block.
            tail = self._allocator.allocate_immutable(tail,
                                                      token_ids=chunk,
                                                      device=device)
        allocated.append(tail)

    return allocated
|
||||
|
||||
def _get_all_token_ids(self) -> List[int]:
    """Return the token IDs of every block, concatenated in table order.

    An unallocated table yields an empty list.
    """
    # NOTE: This function is O(seq_len); use sparingly.
    if not self._is_allocated:
        return []

    return [
        token_id for held_block in self._blocks
        for token_id in held_block.token_ids
    ]
|
||||
|
||||
@property
def _is_allocated(self) -> bool:
    """Whether this table currently holds a block list (`_blocks` is set)."""
    has_no_blocks = self._blocks is None
    return not has_no_blocks
|
||||
|
||||
@property
def _num_empty_slots(self) -> int:
    """Number of token slots still unused across all held blocks."""
    assert self._is_allocated

    total_capacity = len(self._blocks) * self._block_size
    return total_capacity - self._num_full_slots
|
||||
|
||||
@property
def num_full_slots(self) -> int:
    """Total number of tokens currently stored in the BlockTable.

    Returns:
        int: The value of the internal `_num_full_slots` counter.
    """
    stored_tokens = self._num_full_slots
    return stored_tokens
|
185
vllm/core/block/common.py
Normal file
185
vllm/core/block/common.py
Normal file
@ -0,0 +1,185 @@
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator
|
||||
|
||||
# Alias for the integer index that identifies a block.
BlockId = int
|
||||
# Alias for the integer reference count kept per block index.
RefCount = int
|
||||
|
||||
|
||||
class RefCounter:
    """Reference counts for a fixed set of block indices.

    Every block index passed to the constructor starts with a count of
    zero. Counts can be incremented, decremented and queried; all three
    operations assert that the block index is one of the known indices.

    Args:
        all_block_indices (Iterable[BlockId]): The block indices to track.
            Duplicates are collapsed.
    """

    def __init__(self, all_block_indices: Iterable[BlockId]):
        unique_indices = set(all_block_indices)
        self._refcounts: Dict[BlockId,
                              RefCount] = dict.fromkeys(unique_indices, 0)

    def incr(self, block_id: BlockId) -> RefCount:
        """Add one reference to `block_id` and return the new count."""
        assert block_id in self._refcounts
        current = self._refcounts[block_id]

        assert current >= 0

        updated = current + 1
        self._refcounts[block_id] = updated
        return updated

    def decr(self, block_id: BlockId) -> RefCount:
        """Drop one reference from `block_id` and return the new count.

        Asserts that the count is positive before decrementing.
        """
        assert block_id in self._refcounts
        current = self._refcounts[block_id]

        assert current > 0

        updated = current - 1
        self._refcounts[block_id] = updated
        return updated

    def get(self, block_id: BlockId) -> RefCount:
        """Return the current reference count of `block_id`."""
        assert block_id in self._refcounts
        return self._refcounts[block_id]

    def as_readonly(self) -> "ReadOnlyRefCounter":
        """Return a ReadOnlyRefCounter wrapping this counter."""
        return ReadOnlyRefCounter(self)
|
||||
|
||||
|
||||
class ReadOnlyRefCounter:
    """Read-only wrapper around a RefCounter.

    Lookups are forwarded to the wrapped counter; attempts to increment or
    decrement raise `ValueError`.

    Args:
        refcounter (RefCounter): The counter to expose read-only.
    """

    def __init__(self, refcounter: RefCounter):
        self._refcounter = refcounter

    def incr(self, block_id: BlockId) -> RefCount:
        """Always raises ValueError: this view cannot modify counts."""
        raise ValueError("Incr not allowed")

    def decr(self, block_id: BlockId) -> RefCount:
        """Always raises ValueError: this view cannot modify counts."""
        raise ValueError("Decr not allowed")

    def get(self, block_id: BlockId) -> RefCount:
        """Return the wrapped counter's reference count for `block_id`."""
        wrapped = self._refcounter
        return wrapped.get(block_id)
|
||||
|
||||
|
||||
class CopyOnWriteTracker:
    """Records pending copy-on-write operations between blocks.

    The tracker keeps a mapping from a source block index to the list of
    destination block indices that were allocated to replace it. It reads
    reference counts from the given counter and uses the given allocator
    to free the shared block and allocate its replacement.

    Args:
        refcounter (RefCounter): Source of block reference counts.
        allocator (BlockAllocator): Allocator used to free and allocate
            blocks.
    """

    def __init__(
        self,
        refcounter: RefCounter,
        allocator: BlockAllocator,
    ):
        self._copy_on_writes = defaultdict(list)
        self._refcounter = refcounter
        self._allocator = allocator

    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        """Give `block` a private block index if its current one is shared.

        A block whose reference count is greater than 1 is treated as
        shared. In that case the block is passed to the allocator's `free`,
        a new mutable block is allocated with the same `prev_block`, and
        the (source, destination) pair is recorded for `clear_cows`.

        Args:
            block (Block): The block to check for copy-on-write.

        Returns:
            Optional[BlockId]: The new block index if a copy-on-write was
                recorded; otherwise the block's existing index, which is
                `None` when the block has no index.
        """
        source_id = block.block_id
        if source_id is None:
            return None

        refcount = self._refcounter.get(source_id)
        assert refcount != 0

        if refcount <= 1:
            # Not shared: the block can be appended to in place.
            return source_id

        # Release this block's hold on the shared index before allocating
        # the replacement.
        self._allocator.free(block)

        replacement = self._allocator.allocate_mutable(
            prev_block=block.prev_block)
        destination_id = replacement.block_id

        # Track src/dst copy.
        self._copy_on_writes[source_id].append(destination_id)
        return destination_id

    def clear_cows(self) -> Dict[BlockId, List[BlockId]]:
        """Return the recorded copy-on-writes and reset the tracker.

        Returns:
            Dict[BlockId, List[BlockId]]: Source block index mapped to the
                destination block indices recorded since the last call.
        """
        recorded = dict(self._copy_on_writes)
        self._copy_on_writes.clear()
        return recorded
|
||||
|
||||
|
||||
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
    """Retrieves all the blocks in a sequence starting from the last block.

    The chain is followed backwards through each block's `prev_block` until
    a block with no predecessor is reached, and the blocks are returned in
    forward order (first block first, `last_block` last).

    Despite the name, the traversal is iterative: a recursive walk uses one
    stack frame per block and raises RecursionError once the chain is
    longer than the interpreter's recursion limit.

    Args:
        last_block (Block): The last block in the sequence.

    Returns:
        List[Block]: A list of all the blocks in the sequence, in the order
            they appear.
    """
    # Collected tail-first, then reversed once at the end.
    chain: List[Block] = [last_block]

    predecessor = last_block.prev_block
    while predecessor is not None:
        chain.append(predecessor)
        predecessor = predecessor.prev_block

    chain.reverse()
    return chain
|
206
vllm/core/block/cpu_gpu_block_allocator.py
Normal file
206
vllm/core/block/cpu_gpu_block_allocator.py
Normal file
@ -0,0 +1,206 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from vllm.core.block.interfaces import (Block, BlockAllocator,
|
||||
DeviceAwareBlockAllocator)
|
||||
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
|
||||
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
|
||||
from vllm.utils import Device
|
||||
|
||||
|
||||
class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
    """A block allocator that can allocate blocks on both CPU and GPU memory.

    This class implements the `DeviceAwareBlockAllocator` interface by
    holding one underlying `BlockAllocator` per device and routing each
    call to the right one: allocation calls are routed by the `device`
    argument, while `free` and `fork` are routed by looking up which
    allocator owns the block's id.
    """

    @staticmethod
    def create(
        allocator_type: str,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        block_size: int,
    ) -> DeviceAwareBlockAllocator:
        """Creates a CpuGpuBlockAllocator instance with the specified
        configuration.

        Args:
            allocator_type (str): The type of block allocator to use for CPU
                and GPU blocks. Currently supported values are "naive" and
                "prefix_caching".
            num_gpu_blocks (int): The number of blocks to allocate for GPU
                memory.
            num_cpu_blocks (int): The number of blocks to allocate for CPU
                memory.
            block_size (int): The size of each block in number of tokens.

        Returns:
            DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the
                specified configuration.

        Raises:
            ValueError: If `allocator_type` is not a supported value.

        Notes:
            - The block IDs are assigned contiguously, with GPU block IDs coming
                before CPU block IDs.
        """
        block_ids = list(range(num_gpu_blocks + num_cpu_blocks))
        gpu_block_ids = block_ids[:num_gpu_blocks]
        cpu_block_ids = block_ids[num_gpu_blocks:]

        if allocator_type == "naive":
            gpu_allocator = NaiveBlockAllocator(
                create_block=NaiveBlock,
                num_blocks=num_gpu_blocks,
                block_size=block_size,
                block_ids=gpu_block_ids,
            )

            cpu_allocator = NaiveBlockAllocator(
                create_block=NaiveBlock,
                num_blocks=num_cpu_blocks,
                block_size=block_size,
                block_ids=cpu_block_ids,
            )
        elif allocator_type == "prefix_caching":
            gpu_allocator = PrefixCachingBlockAllocator(
                num_blocks=num_gpu_blocks,
                block_size=block_size,
                block_ids=gpu_block_ids,
            )

            cpu_allocator = PrefixCachingBlockAllocator(
                num_blocks=num_cpu_blocks,
                block_size=block_size,
                block_ids=cpu_block_ids,
            )
        else:
            raise ValueError(f"Unknown allocator type {allocator_type=}")

        return CpuGpuBlockAllocator(
            cpu_block_allocator=cpu_allocator,
            gpu_block_allocator=gpu_allocator,
        )

    def __init__(
        self,
        cpu_block_allocator: BlockAllocator,
        gpu_block_allocator: BlockAllocator,
    ):
        """Wraps one allocator per device.

        Args:
            cpu_block_allocator (BlockAllocator): Allocator for CPU blocks.
            gpu_block_allocator (BlockAllocator): Allocator for GPU blocks.
                Its block ids must not overlap with the CPU allocator's.
        """
        assert not (
            cpu_block_allocator.all_block_ids
            & gpu_block_allocator.all_block_ids
        ), "cpu and gpu block allocators can't have intersection of block ids"

        self._allocators = {
            Device.CPU: cpu_block_allocator,
            Device.GPU: gpu_block_allocator,
        }

        # Reverse lookup used by free() and fork(), which receive a block
        # rather than a device.
        self._block_ids_to_allocator = {}
        for allocator in self._allocators.values():
            for block_id in allocator.all_block_ids:
                self._block_ids_to_allocator[block_id] = allocator

    def allocate_mutable(self, prev_block: Optional[Block],
                         device: Device) -> Block:
        """Allocates a new mutable block on the specified device.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence.
                Used for prefix hashing.
            device (Device): The device on which to allocate the new block.

        Returns:
            Block: The newly allocated mutable block.
        """
        return self._allocators[device].allocate_mutable(prev_block)

    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int], device: Device) -> Block:
        """Allocates a new immutable block with the provided token IDs on the
        specified device.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence.
                Used for prefix hashing.
            token_ids (List[int]): The list of token IDs to be stored in the new
                block.
            device (Device): The device on which to allocate the new block.

        Returns:
            Block: The newly allocated immutable block containing the provided
                token IDs.
        """
        return self._allocators[device].allocate_immutable(
            prev_block, token_ids)

    def free(self, block: Block) -> None:
        """Frees the memory occupied by the given block.

        The call is forwarded to whichever underlying allocator owns the
        block's id.

        Args:
            block (Block): The block to be freed.
        """
        allocator = self._block_ids_to_allocator[block.block_id]
        return allocator.free(block)

    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
        memory as the original sequence.

        The call is forwarded to the allocator that owns `last_block`'s id.

        Args:
            last_block (Block): The last block in the original sequence.

        Returns:
            List[Block]: A new list of blocks that shares the same memory as the
                original sequence.
        """
        allocator = self._block_ids_to_allocator[last_block.block_id]
        return allocator.fork(last_block)

    def get_num_free_blocks(self, device: Device) -> int:
        """Returns the number of free blocks available on the specified device.

        Args:
            device (Device): The device for which to query the number of free
                blocks.

        Returns:
            int: The number of free blocks available on the specified device.
        """
        return self._allocators[device].get_num_free_blocks()

    def clear_copy_on_writes(self) -> Dict[int, List[int]]:
        """Clears the copy-on-write (CoW) state and returns the mapping of
            source to destination block IDs.

        Returns:
            Dict[int, List[int]]: A dictionary mapping source block IDs to lists
                of destination block IDs.
        """
        # CoW only supported on GPU
        device = Device.GPU
        return self._allocators[device].clear_copy_on_writes()

    def mark_blocks_as_computed(self) -> None:
        """Forwards to the GPU allocator's `mark_blocks_as_computed`."""
        # Prefix caching only supported on GPU.
        device = Device.GPU
        return self._allocators[device].mark_blocks_as_computed()

    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Forwards to the GPU allocator's `get_common_computed_block_ids`.

        Args:
            seq_block_ids (List[List[int]]): Block id lists, one per sequence.

        Returns:
            List[int]: Whatever the GPU allocator returns for these lists.
        """
        # Prefix caching only supported on GPU.
        device = Device.GPU
        return self._allocators[device].get_common_computed_block_ids(
            seq_block_ids)

    # Exposed as a property to match the BlockAllocator interface, which
    # declares `all_block_ids` as an abstract property; __init__ above also
    # reads `allocator.all_block_ids` without calling it. The annotation is
    # quoted so it is not evaluated at definition time (builtin generics
    # are not subscriptable before Python 3.9).
    @property
    def all_block_ids(self) -> "frozenset[int]":
        """The ids of every block managed by either underlying allocator."""
        return frozenset(self._block_ids_to_allocator.keys())
|
105
vllm/core/block/interfaces.py
Normal file
105
vllm/core/block/interfaces.py
Normal file
@ -0,0 +1,105 @@
|
||||
from abc import ABC, abstractmethod, abstractproperty
|
||||
from typing import Dict, List, Optional, Protocol
|
||||
|
||||
from vllm.utils import Device
|
||||
|
||||
|
||||
class Block(ABC):
    """Abstract interface for one block of token IDs in a chain of blocks.

    Read-only attributes are declared with `@property` stacked on
    `@abstractmethod`, which replaces the deprecated `abstractproperty`.
    """

    @abstractmethod
    def append_token_ids(self, token_ids: List[int]) -> None:
        """Append `token_ids` to this block."""

    @property
    @abstractmethod
    def block_id(self) -> Optional[int]:
        """The block's index, or `None` if it has none."""

    @property
    @abstractmethod
    def token_ids(self) -> List[int]:
        """The token IDs currently stored in this block."""

    @property
    @abstractmethod
    def num_empty_slots(self) -> int:
        """How many more token IDs this block can hold."""

    @property
    @abstractmethod
    def is_full(self) -> bool:
        """Whether the block has no empty slots left."""

    @property
    @abstractmethod
    def prev_block(self) -> Optional["Block"]:
        """The preceding block in the sequence, or `None` for the first."""

    class Factory(Protocol):
        """Call signature of a callable that constructs a `Block`."""

        @abstractmethod
        def __call__(
            self,
            prev_block: Optional["Block"],
            token_ids: List[int],
            block_size: int,
            allocator: "BlockAllocator",
            block_id: Optional[int] = None,
        ) -> "Block":
            """Build a block from the given arguments."""
|
||||
|
||||
|
||||
class BlockAllocator(ABC):
    """Abstract interface for an allocator that hands out `Block` objects.

    `all_block_ids` is declared with `@property` stacked on
    `@abstractmethod`, which replaces the deprecated `abstractproperty`.
    """

    @abstractmethod
    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
        """Allocate a new mutable block linked to `prev_block`."""

    @abstractmethod
    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int]) -> Block:
        """Allocate a new block holding `token_ids`, linked to `prev_block`."""

    @abstractmethod
    def free(self, block: Block) -> None:
        """Release `block`."""

    @abstractmethod
    def fork(self, last_block: Block) -> List[Block]:
        """Fork the sequence of blocks that ends at `last_block`."""

    @abstractmethod
    def get_num_free_blocks(self) -> int:
        """Return the number of blocks that are currently free."""

    # The annotation is quoted so it is not evaluated at definition time
    # (builtin generics are not subscriptable before Python 3.9).
    @property
    @abstractmethod
    def all_block_ids(self) -> "frozenset[int]":
        """The ids of every block this allocator manages."""

    @abstractmethod
    def clear_copy_on_writes(self) -> Dict[int, List[int]]:
        """Return the source -> destinations copy-on-write mapping and
        clear it."""

    @abstractmethod
    def mark_blocks_as_computed(self) -> None:
        """Mark blocks as computed."""

    @abstractmethod
    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Return the computed block ids common to the given sequences."""

    class NoFreeBlocksError(ValueError):
        """Raised when an allocation is requested but no block is free."""
|
||||
|
||||
|
||||
class DeviceAwareBlockAllocator(BlockAllocator):
    """A `BlockAllocator` whose allocation and free-count calls take a
    target `Device`."""

    @abstractmethod
    def allocate_mutable(self, prev_block: Optional[Block],
                         device: Device) -> Block:
        """Allocate a new mutable block on `device`, linked to `prev_block`."""

    @abstractmethod
    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int], device: Device) -> Block:
        """Allocate a new block holding `token_ids` on `device`, linked to
        `prev_block`."""

    @abstractmethod
    def get_num_free_blocks(self, device: Device) -> int:
        """Return the number of free blocks on `device`."""
|
275
vllm/core/block/naive_block.py
Normal file
275
vllm/core/block/naive_block.py
Normal file
@ -0,0 +1,275 @@
|
||||
from typing import Dict, Iterable, List, Optional, Set
|
||||
|
||||
from vllm.core.block.common import (CopyOnWriteTracker, RefCounter,
|
||||
get_all_blocks_recursively)
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator
|
||||
|
||||
# Alias for the integer index that identifies a block.
BlockId = int
|
||||
# Alias for an integer reference count.
Refcount = int
|
||||
|
||||
|
||||
class NaiveBlockAllocator(BlockAllocator):
    """A simple block allocator that manages blocks of memory without prefix
    caching.

    Args:
        create_block (Block.Factory): A factory function for creating new
            blocks. This is used when a NaiveBlockAllocator is composed within
            a prefix caching allocator -- the naive block allocator must
            construct prefix caching blocks (but shouldn't know anything else
            about them).
        num_blocks (int): The total number of blocks to manage.
        block_size (int): The size of each block in tokens.
        block_ids (Optional[Iterable[int]], optional): An optional iterable of
            block IDs. If not provided, block IDs will be assigned sequentially
            from 0 to num_blocks - 1.
    """

    def __init__(
        self,
        create_block: Block.Factory,
        num_blocks: int,
        block_size: int,
        block_ids: Optional[Iterable[int]] = None,
    ):
        if block_ids is None:
            block_ids = range(num_blocks)

        # `block_ids` is only promised to be an Iterable, so it may be a
        # one-shot iterator. Materialize it once; otherwise the second of
        # the two set constructions below would see an exhausted iterator.
        block_ids = list(block_ids)

        self._free_block_indices: Set[BlockId] = set(block_ids)
        self._all_block_indices = frozenset(block_ids)
        assert len(self._all_block_indices) == num_blocks

        self._refcounter = RefCounter(
            all_block_indices=self._free_block_indices)
        self._create_block = create_block
        self._block_size = block_size

        # The tracker only reads refcounts; it frees and allocates blocks
        # through this allocator.
        self._cow_tracker = CopyOnWriteTracker(
            refcounter=self._refcounter.as_readonly(),
            allocator=self,
        )

    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int]) -> Block:
        """Allocates a new immutable block with the given token IDs, linked to
        the previous block.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence. If
                None, then the block to be allocated is the first block in the
                sequence.
            token_ids (List[int]): The token IDs to be stored in the new block.

        Returns:
            Block: The newly allocated immutable block.
        """
        block = self.allocate_mutable(prev_block=prev_block)
        block.append_token_ids(token_ids)
        return block

    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
        """Allocates a new mutable block, linked to the previous block.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence. If
                None, then the block to be allocated is the first block in the
                sequence.

        Returns:
            Block: The newly allocated mutable block.

        Raises:
            BlockAllocator.NoFreeBlocksError: If no block index is free.
        """
        block_id = self._allocate_new_block_id()
        return self._create_block(
            prev_block=prev_block,
            token_ids=[],
            block_id=block_id,
            block_size=self._block_size,
            allocator=self,
        )

    def free(self, block: Block) -> None:
        """Drops one reference to the block's index and detaches the block.

        The index returns to the free set once its reference count reaches
        zero.

        Args:
            block (Block): The block to be freed.
        """
        self._free_block_id(block.block_id)

        # Mark the block as having no allocation.
        block.block_id = None

    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
        memory as the original sequence.

        Args:
            last_block (Block): The last block in the original sequence.

        Returns:
            List[Block]: The new sequence of blocks that shares the same memory
                as the original sequence.
        """
        source_blocks = get_all_blocks_recursively(last_block)

        forked_blocks = []
        prev_block = None
        for block in source_blocks:

            # Increment refcount for each block.
            refcount = self._refcounter.incr(block.block_id)
            assert refcount != 1, "can't fork free'd block"

            forked_blocks.append(
                self._create_block(
                    prev_block=prev_block,
                    token_ids=block.token_ids,
                    block_id=block.block_id,
                    block_size=self._block_size,
                    allocator=self,
                ))
            prev_block = forked_blocks[-1]

        return forked_blocks

    def get_num_free_blocks(self) -> int:
        """Returns the number of block indices currently in the free set."""
        return len(self._free_block_indices)

    def _allocate_new_block_id(self) -> BlockId:
        """Takes one index out of the free set and gives it one reference."""
        if not self._free_block_indices:
            raise BlockAllocator.NoFreeBlocksError()

        block_id = next(iter(self._free_block_indices))
        self._refcounter.incr(block_id)
        self._free_block_indices.remove(block_id)
        return block_id

    def _free_block_id(self, block_id: BlockId) -> None:
        """Drops one reference; returns the index to the free set at zero."""
        refcount = self._refcounter.decr(block_id)
        if refcount == 0:
            self._free_block_indices.add(block_id)

    @property
    def refcounter(self) -> RefCounter:
        """The reference counter used by this allocator."""
        return self._refcounter

    @property
    def all_block_ids(self) -> "frozenset[int]":
        """The ids of every block managed by this allocator."""
        return self._all_block_indices

    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        """Performs a copy-on-write operation on the given block if it is not
        appendable.

        Args:
            block (Block): The block to check for copy-on-write.

        Returns:
            Optional[BlockId]: The block index of the new block if a copy-on
                -write operation was performed, or the original block index if
                no copy-on-write was necessary.
        """
        return self._cow_tracker.cow_block_if_not_appendable(block)

    def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]:
        """Returns the copy-on-write source->destination mapping and clears it.

        Returns:
            Dict[BlockId, List[BlockId]]: A dictionary mapping source
                block indices to lists of destination block indices.
        """
        return self._cow_tracker.clear_cows()

    def mark_blocks_as_computed(self) -> None:
        """Mark blocks as computed, used in prefix caching.

        Since the naive allocator does not implement prefix caching, we do
        nothing.
        """
        pass

    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Determine blocks that can be skipped in prefill.

        Since the naive allocator does not support prefix caching, always return
        an empty list.
        """
        return []
|
||||
|
||||
|
||||
class NaiveBlock(Block):
    """An implementation of the Block class that does not support prefix
    caching.

    The NaiveBlock class represents a block of token IDs with a fixed size. It
    provides methods for appending token IDs to the block and manages copy-on
    -write operations when necessary.

    Args:
        prev_block (Optional[Block]): The previous block in the sequence, or
            None when this is the first block.
        token_ids (List[int]): The initial token IDs to be stored in the block.
        block_size (int): The maximum number of token IDs that can be stored in
            the block.
        allocator (BlockAllocator): The block allocator associated with this
            block.
        block_id (Optional[int], optional): The physical block index
            of this block. Defaults to None, which means no allocation has been
            made.
        _cow_target (Optional[Block], optional): The copy-on-write target block.
            If not provided, it defaults to self.
    """

    def __init__(self,
                 prev_block: Optional[Block],
                 token_ids: List[int],
                 block_size: int,
                 allocator: BlockAllocator,
                 block_id: Optional[int] = None,
                 _cow_target: Optional[Block] = None):
        self._token_ids = []
        self._block_size = block_size
        self._prev_block = prev_block
        self._block_id = block_id
        self._allocator = allocator
        self._cow_target = _cow_target if _cow_target is not None else self

        # The initial tokens are copied into this block's own list and never
        # trigger a copy-on-write.
        self._append_token_ids_no_cow(token_ids)

    def append_token_ids(self, token_ids: List[int]) -> None:
        """Appends the given token IDs to the block, instructing the allocator
        to perform a copy-on-write if necessary.

        Args:
            token_ids (List[int]): The token IDs to be appended to the block.
        """
        self._append_token_ids_no_cow(token_ids)

        # Only a block that currently has a physical index can be shared;
        # the allocator returns the index this block should use from now on.
        if self._block_id is not None:
            self._block_id = (self._allocator.cow_block_if_not_appendable(
                self._cow_target))

    def _append_token_ids_no_cow(self, token_ids: List[int]) -> None:
        """Extends the stored tokens; asserts that they fit in the block."""
        assert self.num_empty_slots >= len(token_ids)
        self._token_ids.extend(token_ids)

    @property
    def block_id(self) -> Optional[int]:
        """The physical block index, or None when no allocation is held."""
        return self._block_id

    @block_id.setter
    def block_id(self, value: Optional[int]) -> None:
        self._block_id = value

    @property
    def is_full(self) -> bool:
        """Whether the block has no empty slots left."""
        return self.num_empty_slots == 0

    @property
    def num_empty_slots(self) -> int:
        """Block size minus the number of tokens currently stored."""
        return self._block_size - len(self._token_ids)

    @property
    def token_ids(self) -> List[int]:
        """The block's internal token list (not a copy)."""
        return self._token_ids

    def block_size(self) -> int:
        """Returns the maximum number of token IDs this block can hold.

        Unlike the neighbouring accessors this is a regular method, not a
        property.
        """
        return self._block_size

    @property
    def prev_block(self) -> Optional["Block"]:
        """The previous block in the sequence, or None for the first block."""
        return self._prev_block
|
472
vllm/core/block/prefix_caching_block.py
Normal file
472
vllm/core/block/prefix_caching_block.py
Normal file
@ -0,0 +1,472 @@
|
||||
"""Token blocks."""
|
||||
from itertools import takewhile
|
||||
from os.path import commonprefix
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
from vllm.core.block.common import (CopyOnWriteTracker,
|
||||
get_all_blocks_recursively)
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator
|
||||
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
|
||||
|
||||
# Alias for the integer hash used as the key of the cached-block mappings.
PrefixHash = int
|
||||
# Alias for the integer index that identifies a block.
BlockId = int
|
||||
|
||||
|
||||
class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
"""A block allocator that implements prefix caching.
|
||||
|
||||
The PrefixCachingBlockAllocator maintains a cache of blocks based on their
|
||||
content hash. It reuses blocks with the same content hash to avoid redundant
|
||||
memory allocation. The allocator also supports copy-on-write operations.
|
||||
|
||||
Args:
|
||||
num_blocks (int): The total number of blocks to manage.
|
||||
block_size (int): The size of each block in tokens.
|
||||
block_ids(Optional[Iterable[int]], optional): An optional iterable of
|
||||
block IDs. If not provided, block IDs will be assigned sequentially
|
||||
from 0 to num_blocks - 1.
|
||||
"""
|
||||
|
||||
# TODO last access time / evictor integration
|
||||
|
||||
def __init__(
    self,
    num_blocks: int,
    block_size: int,
    block_ids: Optional[Iterable[int]] = None,
):
    """Create the hash-keyed caches and the wrapped hashless allocator.

    Args:
        num_blocks (int): Forwarded to the hashless allocator.
        block_size (int): Size of each block in tokens; also forwarded.
        block_ids (Optional[Iterable[int]]): Forwarded to the hashless
            allocator unchanged.
    """
    # Prefix hash -> block id for every block that has a prefix hash,
    # including blocks whose refcount has dropped to 0.
    self._cached_blocks: Dict[PrefixHash, BlockId] = {}

    # Prefix hash -> block id for hashed blocks with refcount 0; a subset
    # of self._cached_blocks.
    self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {}

    # Blocks that do not (yet) have a prefix hash are handled by this
    # allocator. It builds blocks through our own factory.
    hashless_allocator = NaiveBlockAllocator(
        create_block=self._create_block,
        num_blocks=num_blocks,
        block_size=block_size,
        block_ids=block_ids,
    )
    self._hashless_allocator = hashless_allocator

    self._block_size = block_size

    # The refcounter is shared with the hashless allocator so that a block
    # first handed out there can later be promoted to an immutable block.
    self._refcounter = hashless_allocator.refcounter

    self._cow_tracker = CopyOnWriteTracker(
        refcounter=self._refcounter.as_readonly(),
        allocator=self,
    )
|
||||
|
||||
# Implements Block.Factory.
|
||||
def _create_block(
    self,
    prev_block: Optional[Block],
    token_ids: List[int],
    block_size: int,
    allocator: BlockAllocator,
    block_id: Optional[int] = None,
) -> Block:
    """Build a ``PrefixCachingBlock`` that is always bound to this allocator.

    The ``allocator`` argument is accepted but not used: the created block
    receives ``self`` as its prefix-caching allocator regardless of what
    was passed in.
    """
    return PrefixCachingBlock(
        prev_block=prev_block,
        token_ids=token_ids,
        block_size=block_size,
        block_id=block_id,
        # Bind the block to self, not to the caller-supplied allocator.
        prefix_caching_allocator=self,
    )
|
||||
|
||||
def allocate_immutable(self, prev_block: Optional[Block],
                       token_ids: List[int]) -> Block:
    """Allocate an immutable block holding ``token_ids``, reusing a cached
    block with the same content hash when one exists.

    Args:
        prev_block (Optional[Block]): The previous block in the sequence.
        token_ids (List[int]): The token IDs to be stored in the block.

    Returns:
        Block: The allocated immutable block.
    """
    assert_prefix_caching_block_or_none(prev_block)

    # Build an unbound block first, only to learn its content hash.
    candidate = self._create_block(
        prev_block=prev_block,
        token_ids=token_ids,
        block_size=self._block_size,
        allocator=self,
    )
    content_hash = candidate.content_hash
    assert content_hash is not None

    cached_block_id = self._cached_blocks.get(content_hash, None)
    if cached_block_id is None:
        # Cache miss: take a mutable block and fill it. Appending the
        # tokens is expected to leave the block with a content hash.
        fresh = self.allocate_mutable(prev_block)
        fresh.append_token_ids(token_ids)
        assert fresh.content_hash is not None
        # TODO computed bit
        return fresh

    # Cache hit: point the candidate at the cached block id and take a
    # reference on it.
    candidate.block_id = cached_block_id
    self._incr_refcount_cached_block(content_hash, candidate.block_id)
    return candidate
|
||||
|
||||
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
    """Allocate a mutable block, evicting an unused cached block if the
    hashless allocator has no free blocks.

    Args:
        prev_block (Optional[Block]): The previous block in the sequence,
            or None. None is accepted because the argument is validated
            with ``assert_prefix_caching_block_or_none``.

    Returns:
        Block: The allocated mutable block.

    Raises:
        BlockAllocator.NoFreeBlocksError: If the hashless allocator is
            exhausted and there is no unused cached block to evict.
    """
    assert_prefix_caching_block_or_none(prev_block)

    try:
        return self._hashless_allocator.allocate_mutable(
            prev_block=prev_block)
    except BlockAllocator.NoFreeBlocksError:
        # We must check the unused cached blocks before raising OOM.
        pass

    if not self._unused_cached_blocks:
        # No block available in hashless allocator, nor in unused cache
        # blocks.
        raise BlockAllocator.NoFreeBlocksError()

    # TODO policy for selecting block to remove
    content_hash_to_evict = next(iter(self._unused_cached_blocks))

    # Clear content hash mapping; the block will be overwritten.
    del self._cached_blocks[content_hash_to_evict]

    block_id = self._unused_cached_blocks.pop(content_hash_to_evict)
    refcount = self._refcounter.incr(block_id)
    assert refcount == 1

    block = self._create_block(
        prev_block=prev_block,
        token_ids=[],
        block_size=self._block_size,
        allocator=self,
        block_id=block_id,
    )
    # The block is created with no tokens, so it must not have a hash yet.
    assert block.content_hash is None
    return block
|
||||
|
||||
def _incr_refcount_cached_block(self, content_hash: int,
                                block_id: BlockId) -> None:
    """Take a reference on a cached block.

    If the incremented refcount is exactly 1, the block was previously
    unused, so its hash is removed from the unused-cached-blocks mapping.
    """
    if self._refcounter.incr(block_id) != 1:
        return

    assert content_hash in self._unused_cached_blocks
    del self._unused_cached_blocks[content_hash]
|
||||
|
||||
def free(self, block: Block) -> None:
    """Release this block's reference and clear its block id.

    The actual bookkeeping is done by ``_free_block_id_for_block``: blocks
    with a content hash stay in the cache for possible reuse, blocks
    without one are handed to the hashless allocator.
    """
    block_id = block.block_id
    assert (block_id is not None), "freeing unallocated block is undefined"

    self._free_block_id_for_block(block_id, block)
    block.block_id = None
|
||||
|
||||
def _free_block_id_for_block(self, block_id: BlockId,
                             block: Block) -> None:
    """Drop one reference to ``block_id`` on behalf of ``block``.

    A block without a content hash is freed through the hashless allocator.
    Otherwise the shared refcount is decremented and, when it reaches 0,
    the block id is recorded under its hash as an unused cached block.
    """
    assert isinstance(block, PrefixCachingBlock)

    content_hash = block.content_hash
    if content_hash is None:
        return self._hashless_allocator.free(block)

    if self._refcounter.decr(block_id) != 0:
        return

    # If no longer used, add the block to the unused cached blocks.
    assert content_hash not in self._unused_cached_blocks
    assert content_hash in self._cached_blocks
    self._unused_cached_blocks[content_hash] = block_id
|
||||
|
||||
def fork(self, last_block: Block) -> List[Block]:
    """Create a new chain of blocks that reuses the block ids of an
    existing chain.

    Args:
        last_block (Block): The last block in the original sequence.

    Returns:
        List[Block]: Newly created blocks, one per source block, each
            carrying the same block id and token ids as its source and
            linked to the previously created block.
    """
    forked_blocks = []
    prev_block = None

    for source in get_all_blocks_recursively(last_block):
        # Every forked block holds its own reference on the shared id.
        refcount = self._refcounter.incr(source.block_id)
        assert refcount != 1, "can't fork free'd block"

        clone = self._create_block(
            prev_block=prev_block,
            token_ids=source.token_ids,
            block_id=source.block_id,
            block_size=self._block_size,
            allocator=self,
        )
        forked_blocks.append(clone)
        prev_block = clone

    return forked_blocks
|
||||
|
||||
def get_num_free_blocks(self) -> int:
    """Return the hashless allocator's free-block count plus the number of
    unused cached (hashed, refcount 0) blocks."""
    hashless_free = self._hashless_allocator.get_num_free_blocks()
    unused_cached = len(self._unused_cached_blocks)
    return hashless_free + unused_cached
|
||||
|
||||
@property
def all_block_ids(self) -> frozenset[int]:
    """The block id set reported by the wrapped hashless allocator."""
    hashless = self._hashless_allocator
    return hashless.all_block_ids
|
||||
|
||||
def promote_to_immutable_block(self,
                               block: "PrefixCachingBlock") -> BlockId:
    """Register a block that now has a content hash as a cached block.

    If no cached block exists for the hash, this block's id becomes the
    cached one. Otherwise this block's reference on its own id is released
    and a reference is taken on the already-cached id instead.

    Args:
        block (PrefixCachingBlock): The mutable block to be promoted.

    Returns:
        BlockId: Either the original block index, or the block index of
            the previously cached block matching the same content.
    """
    content_hash = block.content_hash
    assert content_hash is not None
    assert block.block_id is not None
    assert self._refcounter.get(block.block_id) > 0

    if content_hash in self._cached_blocks:
        # Duplicate content: give up our own id and share the cached one.
        # NOTE(review): if releasing our id drops its refcount to 0, that
        # id is stored in _unused_cached_blocks under this same hash while
        # _cached_blocks keeps mapping the hash to the other id -- confirm
        # eviction copes with that.
        self._free_block_id_for_block(block.block_id, block)
        self._incr_refcount_cached_block(
            content_hash, self._cached_blocks[content_hash])
    else:
        # First block with this content: it becomes the cached block.
        self._cached_blocks[content_hash] = block.block_id

    return self._cached_blocks[content_hash]
|
||||
|
||||
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
    """Delegate the copy-on-write decision for ``block`` to the tracker.

    Args:
        block (Block): The block to check for copy-on-write.

    Returns:
        Optional[BlockId]: Whatever the copy-on-write tracker's
            ``cow_block_if_not_appendable`` returns for this block.
    """
    tracker = self._cow_tracker
    return tracker.cow_block_if_not_appendable(block)
|
||||
|
||||
def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]:
    """Return the result of the copy-on-write tracker's ``clear_cows``.

    Returns:
        Dict[BlockId, List[BlockId]]: A dictionary mapping source block
            indices to lists of destination block indices.
    """
    tracker = self._cow_tracker
    return tracker.clear_cows()
|
||||
|
||||
def mark_blocks_as_computed(self) -> None:
    """Mark blocks as computed, used in prefix caching.

    Currently a no-op placeholder.
    """
    # TODO Track computed blocks.
    return None
|
||||
|
||||
def get_common_computed_block_ids(
        self, seq_block_ids: List[List[int]]) -> List[int]:
    """Return the block ids that are common for a given sequence group.

    Used in prefill (can skip prefill of some blocks).

    Args:
        seq_block_ids (List[List[int]]): One list of block ids per
            sequence.

    Returns:
        List[int]: The longest common prefix of the per-sequence computed
            block ids, ignoring sequences with no computed blocks. Always
            a list; it is empty while computed-block tracking is not
            implemented.
    """

    # TODO: Track computed blocks.
    def computed(block_id: int) -> bool:
        return False

    # NOTE We exclude the last block to avoid the case where the entire
    # prompt is cached. This would cause erroneous behavior in model
    # runner.
    # takewhile is lazy: materialize each prefix so that the emptiness
    # filter and commonprefix operate on real lists, not iterators.
    ids_list = [
        list(takewhile(computed, seq[:-1])) for seq in seq_block_ids
    ]
    non_empty = [ids for ids in ids_list if ids]

    # commonprefix returns '' for an empty input; keep the return type a
    # list in that case.
    if not non_empty:
        return []
    return commonprefix(non_empty)
|
||||
|
||||
|
||||
class PrefixCachingBlock(Block):
|
||||
"""A block implementation that supports prefix caching.
|
||||
|
||||
The PrefixCachingBlock class represents a block of token IDs with prefix
|
||||
caching capabilities. It wraps a NaiveBlock internally and provides
|
||||
additional functionality for content hashing and promoting immutable blocks
|
||||
with the prefix caching allocator.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[PrefixCachingBlock]): The previous block in the
|
||||
sequence.
|
||||
token_ids (List[int]): The initial token IDs to be stored in the block.
|
||||
block_size (int): The maximum number of token IDs that can be stored in
|
||||
the block.
|
||||
prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix
|
||||
caching block allocator associated with this block.
|
||||
block_id (Optional[int], optional): The physical block index
|
||||
of this block. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    prev_block: Optional["PrefixCachingBlock"],
    token_ids: List[int],
    block_size: int,
    prefix_caching_allocator: PrefixCachingBlockAllocator,
    block_id: Optional[int] = None,
):
    """Wrap a ``NaiveBlock`` and remember the prefix-caching allocator.

    ``prev_block`` must be a ``PrefixCachingBlock`` or None. The content
    hash starts out uncomputed (None).
    """
    assert_prefix_caching_block_or_none(prev_block)

    self._prev_block = prev_block
    self._cached_content_hash: Optional[int] = None
    self._prefix_caching_allocator = prefix_caching_allocator

    # The naive block stores the tokens. It is told to use this wrapper as
    # its copy-on-write target and the prefix-caching allocator as its
    # allocator.
    naive_block = NaiveBlock(
        prev_block=prev_block,
        token_ids=token_ids,
        block_size=block_size,
        block_id=block_id,
        allocator=prefix_caching_allocator,
        _cow_target=self,
    )
    self._block = naive_block
|
||||
|
||||
def append_token_ids(self, token_ids: List[int]) -> None:
    """Append ``token_ids`` and, once a content hash is available, register
    the block with the allocator as immutable.

    The wrapped naive block performs the append (and any copy-on-write).

    Args:
        token_ids (List[int]): The token IDs to be appended to the block.
            Must be non-empty.
    """
    assert token_ids

    # naive block handles CoW.
    self._block.append_token_ids(token_ids)

    # Without a content hash there is nothing to register yet.
    if self.content_hash is None:
        return

    # Promotion may hand back a different block id (an existing cached
    # block with the same content), so adopt whatever is returned.
    allocator = self._prefix_caching_allocator
    self.block_id = allocator.promote_to_immutable_block(self)
|
||||
|
||||
@property
def block_id(self) -> Optional[int]:
    """Block id of the wrapped naive block."""
    return self._block.block_id

@block_id.setter
def block_id(self, value) -> None:
    # Writes go straight through to the wrapped naive block.
    self._block.block_id = value
|
||||
|
||||
@property
def is_full(self) -> bool:
    """Whether the wrapped naive block reports itself as full."""
    inner = self._block
    return inner.is_full
|
||||
|
||||
@property
def num_empty_slots(self) -> int:
    """Number of empty slots reported by the wrapped naive block."""
    inner = self._block
    return inner.num_empty_slots
|
||||
|
||||
@property
def block_size(self) -> int:
    """The ``block_size`` attribute of the wrapped naive block."""
    inner = self._block
    return inner.block_size
|
||||
|
||||
@property
def token_ids(self) -> List[int]:
    """Token ids held by the wrapped naive block."""
    inner = self._block
    return inner.token_ids
|
||||
|
||||
@property
def prev_block(self) -> Optional[Block]:
    """The previous block given at construction time, or None."""
    return self._prev_block
|
||||
|
||||
@property
def content_hash(self) -> Optional[int]:
    """Content-based hash of this block, or None if it is not defined yet.

    The hash is defined only when this block is full and, if there is a
    previous block, that block already has a hash. Once computed, the
    value is stored and returned on later accesses.
    """
    # Fast path: reuse the stored hash.
    stored = self._cached_content_hash
    if stored is not None:
        return stored

    # A block that is not full has no hash.
    if not self.is_full:
        return None

    previous = self._prev_block
    if previous is None:
        is_first_block = True
        prev_block_hash = None
    else:
        is_first_block = False
        prev_block_hash = previous.content_hash
        # Previous block exists but does not yet have a hash.
        if prev_block_hash is None:
            return None

    self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
        is_first_block,
        prev_block_hash,
        cur_block_token_ids=self.token_ids)
    return self._cached_content_hash
|
||||
|
||||
@staticmethod
|
||||
def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
|
||||
cur_block_token_ids: List[int]) -> int:
|
||||
"""Computes a hash value corresponding to the contents of a block and
|
||||
the contents of the preceding block(s). The hash value is used for
|
||||
prefix caching.
|
||||
|
||||
NOTE: Content-based hashing does not yet support LoRA.
|
||||
|
||||
Parameters:
|
||||
- is_first_block (bool): A flag indicating if the block is the first in
|
||||
the sequence.
|
||||
- prev_block_hash (Optional[int]): The hash of the previous block. None
|
||||
if this is the first block.
|
||||
- cur_block_token_ids (List[int]): A list of token ids in the current
|
||||
block. The current block is assumed to be full.
|
||||
|
||||
Returns:
|
||||
- int: The computed hash value for the block.
|
||||
"""
|
||||
assert (prev_block_hash is None) == is_first_block
|
||||
return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
|
||||
|
||||
|
||||
def assert_prefix_caching_block_or_none(block: Optional[Block]):
    """Assert that ``block`` is either None or a ``PrefixCachingBlock``."""
    assert block is None or isinstance(block, PrefixCachingBlock)
|
@ -1,5 +1,4 @@
|
||||
"""A block manager that manages token blocks."""
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import count, takewhile
|
||||
from os.path import commonprefix
|
||||
@ -7,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from vllm.block import BlockTable, PhysicalTokenBlock
|
||||
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.utils import Device
|
||||
@ -196,21 +196,7 @@ class UncachedBlockAllocator(BlockAllocatorBase):
|
||||
"Invalid codepath for uncached block allocator.")
|
||||
|
||||
|
||||
class AllocStatus(enum.Enum):
|
||||
"""Result for BlockSpaceManager.can_allocate
|
||||
|
||||
1. Ok: seq_group can be allocated now.
|
||||
2. Later: seq_group cannot be allocated.
|
||||
The capacity of allocator is larger than seq_group required.
|
||||
3. Never: seq_group can never be allocated.
|
||||
The seq_group is too large to allocated in GPU.
|
||||
"""
|
||||
OK = enum.auto()
|
||||
LATER = enum.auto()
|
||||
NEVER = enum.auto()
|
||||
|
||||
|
||||
class BlockSpaceManager:
|
||||
class BlockSpaceManagerV1(BlockSpaceManager):
|
||||
"""Manages the mapping between logical and physical token blocks."""
|
||||
|
||||
def __init__(
|
||||
@ -355,6 +341,11 @@ class BlockSpaceManager:
|
||||
self,
|
||||
seq: Sequence,
|
||||
) -> PhysicalTokenBlock:
|
||||
# Called before a new block is appended.
|
||||
# This is in charge of allocating a new physical block (to be appended).
|
||||
|
||||
# None if the last block is not full. Otherwise, we set it to the
|
||||
# content hash.
|
||||
if not self.enable_caching:
|
||||
return self.gpu_allocator.allocate()
|
||||
block_hash: Optional[int] = None
|
||||
@ -362,7 +353,14 @@ class BlockSpaceManager:
|
||||
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
|
||||
num_hashed_tokens = seq.num_hashed_tokens_of_block(
|
||||
len(seq.logical_token_blocks) - 1)
|
||||
|
||||
# num_hashed_tokens is used to compute future hashes
|
||||
# (e.g. in the hashing function, it is used to ask the sequence for
|
||||
# prefix tokens)
|
||||
new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)
|
||||
|
||||
# If the block hash is None, then the block is not full.
|
||||
# If the block is not full, then we expect it to have a refcount of 1.
|
||||
if block_hash is None:
|
||||
assert new_block.ref_count == 1
|
||||
return new_block
|
||||
@ -576,16 +574,16 @@ class BlockSpaceManager:
|
||||
for b in takewhile(lambda b: b.computed, block_table[:-1])
|
||||
]
|
||||
|
||||
def get_common_computed_block_ids(self,
|
||||
seq_group: SequenceGroup) -> List[int]:
|
||||
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
|
||||
"""Return the block ids that are common for a given sequence group.
|
||||
|
||||
Used in prefill (can skip prefill of some blocks).
|
||||
"""
|
||||
# Can return non-empty result only with prefix caching enabled.
|
||||
if not self.enable_caching:
|
||||
return []
|
||||
|
||||
ids_list = [
|
||||
self.get_all_computed_blocks(seq)
|
||||
for seq in iter(seq_group.seqs_dict.values())
|
||||
]
|
||||
ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
|
||||
return commonprefix([ids for ids in ids_list if ids != []])
|
||||
|
||||
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
|
210
vllm/core/block_manager_v2.py
Normal file
210
vllm/core/block_manager_v2.py
Normal file
@ -0,0 +1,210 @@
|
||||
"""A block manager that manages token blocks."""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from vllm.core.block.block_table import BlockTable
|
||||
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.utils import Device
|
||||
|
||||
SeqId = int
|
||||
|
||||
|
||||
class BlockSpaceManagerV2(BlockSpaceManager):
|
||||
"""BlockSpaceManager which manages the allocation of KV cache.
|
||||
|
||||
It owns responsibility for allocation, swapping, allocating memory for
|
||||
autoregressively-generated tokens, and other advanced features such as
|
||||
prefix caching, forking/copy-on-write, and sliding-window memory allocation.
|
||||
|
||||
The current implementation is partial; in particular prefix caching and
|
||||
sliding-window are not feature complete. This class implements the design
|
||||
described in https://github.com/vllm-project/vllm/pull/3492.
|
||||
|
||||
Args:
|
||||
block_size (int): The size of each memory block.
|
||||
num_gpu_blocks (int): The number of memory blocks allocated on GPU.
|
||||
num_cpu_blocks (int): The number of memory blocks allocated on CPU.
|
||||
watermark (float, optional): The threshold used for memory swapping.
|
||||
Defaults to 0.01.
|
||||
sliding_window (Optional[int], optional): The size of the sliding
|
||||
window. Defaults to None.
|
||||
enable_caching (bool, optional): Flag indicating whether caching is
|
||||
enabled. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    block_size: int,
    num_gpu_blocks: int,
    num_cpu_blocks: int,
    watermark: float = 0.01,
    sliding_window: Optional[int] = None,
    enable_caching: bool = False,
) -> None:
    """Store the configuration and create the CPU/GPU block allocator.

    Sliding window and prefix caching are rejected with assertions, and
    ``watermark`` must be non-negative.
    """
    self.block_size = block_size
    self.num_total_gpu_blocks = num_gpu_blocks
    self.num_total_cpu_blocks = num_cpu_blocks

    assert sliding_window is None, "Sliding window not yet supported"
    self.block_sliding_window = None

    self.watermark = watermark
    assert watermark >= 0.0

    assert not enable_caching, "Prefix caching not yet supported"
    self.enable_caching = enable_caching

    # Number of GPU blocks corresponding to the watermark fraction,
    # truncated to an int.
    self.watermark_blocks = int(watermark * num_gpu_blocks)

    # Currently, only naive blocks are supported (no prefix caching).
    self.block_allocator = CpuGpuBlockAllocator.create(
        allocator_type="naive",
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=num_cpu_blocks,
        block_size=block_size,
    )

    # One block table per sequence id.
    self.block_tables: Dict[SeqId, BlockTable] = {}
|
||||
|
||||
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
    """Report whether the group's first waiting sequence fits on the GPU.

    Returns ``NEVER`` if the required blocks would not fit even with all
    GPU blocks free (keeping the watermark), ``OK`` if they fit in the
    currently free blocks (keeping the watermark), and ``LATER`` otherwise.
    """
    # FIXME(woosuk): Here we assume that all sequences in the group share
    # the same prompt. This may not be true for preempted sequences.
    first_waiting = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]

    num_required_blocks = BlockTable.get_num_required_blocks(
        first_waiting.get_token_ids(),
        block_size=self.block_size,
    )

    assert self.block_sliding_window is None
    if self.block_sliding_window is not None:
        num_required_blocks = min(num_required_blocks,
                                  self.block_sliding_window)

    num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
        device=Device.GPU)

    # Use watermark to avoid frequent cache eviction.
    spare_if_all_free = self.num_total_gpu_blocks - num_required_blocks
    if spare_if_all_free < self.watermark_blocks:
        return AllocStatus.NEVER

    spare_now = num_free_gpu_blocks - num_required_blocks
    if spare_now >= self.watermark_blocks:
        return AllocStatus.OK
    return AllocStatus.LATER
|
||||
|
||||
def allocate(self, seq_group: SequenceGroup) -> None:
    """Create block tables for every waiting sequence in ``seq_group``.

    A block table is allocated from the first waiting sequence's token ids;
    each remaining waiting sequence receives a fork of that table.
    """
    waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
    assert not ({seq.seq_id for seq in waiting_seqs}
                & self.block_tables.keys()), "block table already exists"

    # NOTE: Here we assume that all sequences in the group have the same
    # prompt.
    first_seq = waiting_seqs[0]

    block_table = BlockTable(
        block_size=self.block_size,
        block_allocator=self.block_allocator,
    )
    assert self.block_sliding_window is None
    block_table.allocate(first_seq.get_token_ids())
    self.block_tables[first_seq.seq_id] = block_table

    # Assign the block table for each sequence.
    for other_seq in waiting_seqs[1:]:
        self.block_tables[other_seq.seq_id] = block_table.fork()
|
||||
|
||||
def can_append_slot(self, seq_group: SequenceGroup) -> bool:
    """Return True if there is at least one free GPU block per running
    sequence in ``seq_group``.

    This is a simple heuristic.
    """
    free_gpu_blocks = self.block_allocator.get_num_free_blocks(Device.GPU)
    running = seq_group.num_seqs(status=SequenceStatus.RUNNING)
    return running <= free_gpu_blocks
|
||||
|
||||
def append_slot(
    self,
    seq: Sequence,
) -> Optional[Tuple[int, int]]:
    """Append the sequence's not-yet-stored token ids to its block table.

    Copy-on-write records are cleared from the allocator and discarded;
    the method always returns None.
    """
    block_table = self.block_tables[seq.seq_id]

    # Get unseen token ids: everything past the slots already filled.
    already_stored = block_table.num_full_slots
    unseen_token_ids = seq.get_token_ids()[already_stored:]
    assert unseen_token_ids

    block_table.append_token_ids(unseen_token_ids)

    # Clear any copy-on-writes; the result is currently not returned.
    _ = self.block_allocator.clear_copy_on_writes()

    # TODO extend append_slot interface to append_slots
    # @cadedaniel will do in https://github.com/vllm-project/vllm/pull/3250

    return None
|
||||
|
||||
def free(self, seq: Sequence) -> None:
    """Free and forget the block table of ``seq``, if it has one."""
    seq_id = seq.seq_id
    if seq_id not in self.block_tables:
        # Already freed or haven't been scheduled yet.
        return

    # Free first, then drop the entry, so a failing free keeps the table.
    self.block_tables[seq_id].free()
    del self.block_tables[seq_id]
|
||||
|
||||
def get_block_table(self, seq: Sequence) -> List[int]:
    """Return the physical block ids of the sequence's block table.

    Asserts that the sequence has a block table and that none of the ids
    is None.
    """
    seq_id = seq.seq_id
    assert seq_id in self.block_tables

    physical_ids = self.block_tables[seq_id].physical_block_ids
    assert all(block_id is not None for block_id in physical_ids)
    return physical_ids
|
||||
|
||||
def access_all_blocks_in_seq(self, seq, now):
    """No-op placeholder; both arguments are ignored."""
    # TODO add prefix caching support.
    # Tracked here https://github.com/vllm-project/vllm/issues/3667
    return None
|
||||
|
||||
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
    """Ask the block allocator to mark blocks as computed.

    ``seq_group`` is ignored because it's not necessary: after the batch is
    formed by the scheduler, blocks do not need to be marked per sequence
    group -- all blocks in the batch can be marked as computed.
    """
    allocator = self.block_allocator
    allocator.mark_blocks_as_computed()
|
||||
|
||||
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
    """Determine the blocks for which prefill can be skipped.

    With prefix caching we can skip prefill for previously-generated blocks.
    Currently, the attention implementation only supports skipping cached
    blocks if they are a contiguous prefix of cached blocks.

    The physical block ids of every given sequence are collected and
    handed to the block allocator, which computes the common result.
    """
    seq_block_ids = []
    for seq in seqs:
        table = self.block_tables[seq.seq_id]
        seq_block_ids.append(table.physical_block_ids)

    return self.block_allocator.get_common_computed_block_ids(seq_block_ids)
|
||||
|
||||
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
    """Give ``child_seq`` a fork of ``parent_seq``'s block table."""
    parent_table = self.block_tables[parent_seq.seq_id]
    child_table = parent_table.fork()
    self.block_tables[child_seq.seq_id] = child_table
|
||||
|
||||
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
    """Always False; ``swap_in`` is not implemented in this manager."""
    return False
|
||||
|
||||
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
    """Not implemented; always raises ``NotImplementedError``."""
    raise NotImplementedError
|
||||
|
||||
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
    """Always False; ``swap_out`` is not implemented in this manager."""
    return False
|
||||
|
||||
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
    """Not implemented; always raises ``NotImplementedError``."""
    raise NotImplementedError
|
||||
|
||||
def get_num_free_gpu_blocks(self) -> int:
    """Free-block count the allocator reports for ``Device.GPU``."""
    allocator = self.block_allocator
    return allocator.get_num_free_blocks(Device.GPU)
|
||||
|
||||
def get_num_free_cpu_blocks(self) -> int:
    """Free-block count the allocator reports for ``Device.CPU``."""
    allocator = self.block_allocator
    return allocator.get_num_free_blocks(Device.CPU)
|
107
vllm/core/interfaces.py
Normal file
107
vllm/core/interfaces.py
Normal file
@ -0,0 +1,107 @@
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from vllm.sequence import Sequence, SequenceGroup
|
||||
|
||||
|
||||
class AllocStatus(enum.Enum):
    """Result for BlockSpaceManager.can_allocate

    OK:    the seq_group can be allocated now.
    LATER: the seq_group cannot be allocated now, although the allocator's
           capacity is larger than what the seq_group requires.
    NEVER: the seq_group can never be allocated; it is too large to be
           allocated on the GPU.
    """

    OK = enum.auto()
    LATER = enum.auto()
    NEVER = enum.auto()
|
||||
|
||||
|
||||
class BlockSpaceManager(ABC):
    """Abstract interface implemented by the block space managers.

    Use ``get_block_space_manager_class`` to obtain a concrete
    implementation by version string.
    """

    @staticmethod
    def get_block_space_manager_class(version: str):
        """Return the manager class for ``version`` ("v1" or "v2",
        case-insensitive).

        The implementation module is imported only when requested.

        Raises:
            ValueError: If the version is not recognized.
        """
        version = version.lower()

        if version == "v1":
            from vllm.core.block_manager_v1 import BlockSpaceManagerV1
            return BlockSpaceManagerV1
        elif version == "v2":
            from vllm.core.block_manager_v2 import BlockSpaceManagerV2
            return BlockSpaceManagerV2
        else:
            raise ValueError(f"Unknown version {version=}")

    @abstractmethod
    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
        """Abstract: report the allocation status for ``seq_group``."""

    @abstractmethod
    def allocate(self, seq_group: SequenceGroup) -> None:
        """Abstract: allocate for ``seq_group``."""

    @abstractmethod
    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
        """Abstract: report whether a slot can be appended."""

    @abstractmethod
    def append_slot(
        self,
        seq: Sequence,
    ) -> Optional[Tuple[int, int]]:
        """Abstract: append a slot for ``seq``."""

    @abstractmethod
    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Abstract: fork ``parent_seq`` into ``child_seq``."""

    @abstractmethod
    def can_swap_in(self, seq_group: SequenceGroup) -> bool:
        """Abstract: report whether ``seq_group`` can be swapped in."""

    @abstractmethod
    def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Abstract: swap ``seq_group`` in."""

    @abstractmethod
    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        """Abstract: report whether ``seq_group`` can be swapped out."""

    @abstractmethod
    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Abstract: swap ``seq_group`` out."""

    @abstractmethod
    def free(self, seq: Sequence) -> None:
        """Abstract: free what is held for ``seq``."""

    @abstractmethod
    def get_block_table(self, seq: Sequence) -> List[int]:
        """Abstract: return the block table of ``seq``."""

    @abstractmethod
    def get_num_free_gpu_blocks(self) -> int:
        """Abstract: return the number of free GPU blocks."""

    @abstractmethod
    def get_num_free_cpu_blocks(self) -> int:
        """Abstract: return the number of free CPU blocks."""

    @abstractmethod
    def access_all_blocks_in_seq(
        self,
        seq: Sequence,
        access_time: float,
    ) -> None:
        """Abstract: record an access to the blocks of ``seq``."""

    @abstractmethod
    def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
        """Abstract: return the computed block ids common to ``seqs``."""

    @abstractmethod
    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
        """Abstract: mark blocks as computed."""
|
@ -4,7 +4,7 @@ from collections import deque
|
||||
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.core.block_manager import AllocStatus, BlockSpaceManager
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.core.policy import PolicyFactory
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -88,8 +88,13 @@ class Scheduler:
|
||||
|
||||
# Instantiate the scheduling policy.
|
||||
self.policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
|
||||
BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
|
||||
version="v2" if self.scheduler_config.
|
||||
use_v2_block_manager else "v1")
|
||||
|
||||
# Create the block space manager.
|
||||
self.block_manager = BlockSpaceManager(
|
||||
self.block_manager = BlockSpaceManagerImpl(
|
||||
block_size=self.cache_config.block_size,
|
||||
num_gpu_blocks=self.cache_config.num_gpu_blocks,
|
||||
num_cpu_blocks=self.cache_config.num_cpu_blocks,
|
||||
@ -378,6 +383,10 @@ class Scheduler:
|
||||
block_tables[seq_id] = self.block_manager.get_block_table(seq)
|
||||
self.block_manager.access_all_blocks_in_seq(seq, now)
|
||||
|
||||
common_computed_block_nums = (
|
||||
self.block_manager.get_common_computed_block_ids(
|
||||
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
|
||||
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=seq_group.request_id,
|
||||
is_prompt=scheduler_outputs.prompt_run,
|
||||
@ -385,8 +394,7 @@ class Scheduler:
|
||||
sampling_params=seq_group.sampling_params,
|
||||
block_tables=block_tables,
|
||||
lora_request=seq_group.lora_request,
|
||||
computed_block_nums=self.block_manager.
|
||||
get_common_computed_block_ids(seq_group),
|
||||
computed_block_nums=common_computed_block_nums,
|
||||
state=seq_group.state,
|
||||
# `multi_modal_data` will only be present for the 1st comm
|
||||
# between engine and worker.
|
||||
@ -396,6 +404,14 @@ class Scheduler:
|
||||
if scheduler_outputs.prompt_run else None,
|
||||
)
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
# Now that the batch has been created, we can assume all blocks in the
|
||||
# batch will have been computed before the next scheduling invocation.
|
||||
# This is because the engine assumes that a failure in model execution
|
||||
# will crash the vLLM instance / will not retry.
|
||||
for seq_group in scheduler_outputs.scheduled_seq_groups:
|
||||
self.block_manager.mark_blocks_as_computed(seq_group)
|
||||
|
||||
return seq_group_metadata_list, scheduler_outputs
|
||||
|
||||
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||
@ -503,9 +519,6 @@ class Scheduler:
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
seq.status = SequenceStatus.SWAPPED
|
||||
|
||||
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
    """Forward ``seq_group`` to the block manager's method of the same name."""
    manager = self.block_manager
    manager.mark_blocks_as_computed(seq_group)
|
||||
|
||||
def _passed_delay(self, now: float) -> bool:
|
||||
if self.prev_prompt:
|
||||
self.last_prompt_latency = now - self.prev_time
|
||||
|
@ -28,6 +28,7 @@ class EngineArgs:
|
||||
max_parallel_loading_workers: Optional[int] = None
|
||||
block_size: int = 16
|
||||
enable_prefix_caching: bool = False
|
||||
use_v2_block_manager: bool = False
|
||||
swap_space: int = 4 # GiB
|
||||
gpu_memory_utilization: float = 0.90
|
||||
max_num_batched_tokens: Optional[int] = None
|
||||
@ -52,6 +53,9 @@ class EngineArgs:
|
||||
max_cpu_loras: Optional[int] = None
|
||||
device: str = 'auto'
|
||||
ray_workers_use_nsight: bool = False
|
||||
|
||||
forced_num_gpu_blocks: Optional[int] = None
|
||||
|
||||
# Related to Vision-language models such as llava
|
||||
image_input_type: Optional[str] = None
|
||||
image_token_id: Optional[int] = None
|
||||
@ -194,6 +198,9 @@ class EngineArgs:
|
||||
parser.add_argument('--enable-prefix-caching',
|
||||
action='store_true',
|
||||
help='Enables automatic prefix caching')
|
||||
parser.add_argument('--use-v2-block-manager',
|
||||
action='store_true',
|
||||
help='Use BlockSpaceMangerV2')
|
||||
|
||||
parser.add_argument('--seed',
|
||||
type=int,
|
||||
@ -210,6 +217,12 @@ class EngineArgs:
|
||||
help='the fraction of GPU memory to be used for '
|
||||
'the model executor, which can range from 0 to 1.'
|
||||
'If unspecified, will use the default value of 0.9.')
|
||||
parser.add_argument(
|
||||
'--forced-num-gpu-blocks',
|
||||
type=int,
|
||||
default=None,
|
||||
help='If specified, ignore GPU profiling result and use this number'
|
||||
'of GPU blocks. Used for testing preemption.')
|
||||
parser.add_argument('--max-num-batched-tokens',
|
||||
type=int,
|
||||
default=EngineArgs.max_num_batched_tokens,
|
||||
@ -369,6 +382,7 @@ class EngineArgs:
|
||||
cache_config = CacheConfig(self.block_size,
|
||||
self.gpu_memory_utilization,
|
||||
self.swap_space, self.kv_cache_dtype,
|
||||
self.forced_num_gpu_blocks,
|
||||
model_config.get_sliding_window(),
|
||||
self.enable_prefix_caching)
|
||||
parallel_config = ParallelConfig(
|
||||
@ -383,6 +397,7 @@ class EngineArgs:
|
||||
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
|
||||
self.max_num_seqs,
|
||||
model_config.max_model_len,
|
||||
self.use_v2_block_manager,
|
||||
self.scheduler_delay_factor)
|
||||
lora_config = LoRAConfig(
|
||||
max_lora_rank=self.max_lora_rank,
|
||||
|
@ -553,12 +553,6 @@ class LLMEngine:
|
||||
# Update the scheduled sequence groups with the model outputs.
|
||||
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
|
||||
|
||||
# If prefix caching is enabled, mark all blocks in the sequence groups
|
||||
# as completed so that future requests don't attempt to recompute them
|
||||
if self.cache_config.enable_prefix_caching:
|
||||
for seq_group in scheduled_seq_groups:
|
||||
self.scheduler.mark_blocks_as_computed(seq_group)
|
||||
|
||||
for seq_group, outputs in zip(scheduled_seq_groups, output):
|
||||
self._process_sequence_group_outputs(seq_group, outputs)
|
||||
|
||||
|
@ -85,6 +85,12 @@ class GPUExecutor(ExecutorBase):
|
||||
cache_dtype=self.cache_config.cache_dtype,
|
||||
))
|
||||
|
||||
if self.cache_config.forced_num_gpu_blocks is not None:
|
||||
forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks
|
||||
logger.info(f"Replacing profiled {num_gpu_blocks=} with "
|
||||
f"{forced_num_gpu_blocks=}")
|
||||
num_gpu_blocks = forced_num_gpu_blocks
|
||||
|
||||
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
|
||||
f"# CPU blocks: {num_cpu_blocks}")
|
||||
|
||||
|
@ -232,6 +232,13 @@ class RayGPUExecutor(ExecutorBase):
|
||||
# operators can be applied to all workers.
|
||||
num_gpu_blocks = min(b[0] for b in num_blocks)
|
||||
num_cpu_blocks = min(b[1] for b in num_blocks)
|
||||
|
||||
if self.cache_config.forced_num_gpu_blocks is not None:
|
||||
forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks
|
||||
logger.info(f"Replacing profiled {num_gpu_blocks=} with "
|
||||
f"{forced_num_gpu_blocks=}")
|
||||
num_gpu_blocks = forced_num_gpu_blocks
|
||||
|
||||
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
|
||||
f"# CPU blocks: {num_cpu_blocks}")
|
||||
|
||||
|
@ -196,6 +196,8 @@ class Sequence:
|
||||
return self.lora_request.lora_int_id if self.lora_request else 0
|
||||
|
||||
def hash_of_block(self, logical_idx: int) -> int:
|
||||
# TODO This can produce incorrect hash when block size > prompt size
|
||||
|
||||
# Compute the number of tokens in the sequence
|
||||
# TODO: The current hashing function is O(L^2). We should optimize
|
||||
# this in the future.
|
||||
|
@ -227,6 +227,16 @@ def set_cuda_visible_devices(device_ids: List[int]) -> None:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
|
||||
|
||||
|
||||
def chunk_list(lst, chunk_size):
    """Split lst into consecutive slices of at most chunk_size items.

    The slices are collected eagerly and returned as a list (this is not a
    generator); the final slice may be shorter than chunk_size.
    """
    chunks = []
    for start in range(0, len(lst), chunk_size):
        stop = start + chunk_size
        chunks.append(lst[start:stop])
    return chunks
|
||||
|
||||
|
||||
def cdiv(a: int, b: int) -> int:
    """Return a divided by b, rounded up towards positive infinity."""
    # Floor-divide, then bump the quotient by one whenever something is left
    # over; the remainder carries the sign of b, so this rounds up for every
    # sign combination.
    quotient, remainder = divmod(a, b)
    if remainder:
        quotient += 1
    return quotient
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_nvcc_cuda_version() -> Optional[Version]:
|
||||
cuda_home = os.environ.get('CUDA_HOME')
|
||||
|
Loading…
x
Reference in New Issue
Block a user