[Core] Optimize block_manager_v2 vs block_manager_v1 (to make V2 default) (#5602)

This commit is contained in:
Alexander Matveev 2024-07-01 23:10:37 -04:00 committed by GitHub
parent 54600709b6
commit 3476ed0809
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 1182 additions and 525 deletions

View File

@ -46,6 +46,7 @@ def main(args: argparse.Namespace):
load_format=args.load_format,
distributed_executor_backend=args.distributed_executor_backend,
otlp_traces_endpoint=args.otlp_traces_endpoint,
enable_prefix_caching=args.enable_prefix_caching,
)
sampling_params = SamplingParams(
@ -220,6 +221,9 @@ if __name__ == '__main__':
action='store_true',
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
parser.add_argument("--enable-prefix-caching",
action='store_true',
help="Enable automatic prefix caching")
parser.add_argument('--use-v2-block-manager', action='store_true')
parser.add_argument(
"--ray-workers-use-nsight",

View File

@ -474,7 +474,7 @@ class VllmRunner:
req_sample_output_strs: List[str] = []
for sample in req_output.outputs:
output_str = sample.text
output_ids = sample.token_ids
output_ids = list(sample.token_ids)
req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append(prompt_str + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs))

View File

@ -373,8 +373,9 @@ def test_cow(block_size: int, sequence_len: int, append_len: int,
block_size) - (sequence_len // block_size)
original_block_table.allocate(token_ids=token_ids, device=Device.GPU)
original_block_ids = original_block_table.physical_block_ids
original_block_ids = original_block_table.physical_block_ids[:]
print("original_block_ids = {}".format(original_block_ids))
forked_block_table = original_block_table.fork()
# Expect no additional allocation (copy on _write_).
@ -457,7 +458,7 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int,
# Allocate lookahead slots.
original_block_table.ensure_num_empty_slots(lookahead_slots)
original_block_ids = original_block_table.physical_block_ids
original_block_ids = original_block_table.physical_block_ids[:]
forked_block_table = original_block_table.fork()

View File

@ -8,7 +8,7 @@ from vllm.utils import Device, chunk_list
@pytest.mark.parametrize("num_gpu_blocks", [1024])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_allocate_mutable(num_cpu_blocks: int, num_gpu_blocks: int,
def test_allocate_mutable_block(num_cpu_blocks: int, num_gpu_blocks: int,
block_size: int, allocator_type: str):
allocator = CpuGpuBlockAllocator.create(
allocator_type=allocator_type,
@ -21,14 +21,14 @@ def test_allocate_mutable(num_cpu_blocks: int, num_gpu_blocks: int,
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
cpu_blocks = [
allocator.allocate_mutable(prev_block=None, device=Device.CPU)
allocator.allocate_mutable_block(prev_block=None, device=Device.CPU)
for _ in range(num_cpu_blocks)
]
assert allocator.get_num_free_blocks(Device.CPU) == 0
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
gpu_blocks = [
allocator.allocate_mutable(prev_block=None, device=Device.GPU)
allocator.allocate_mutable_block(prev_block=None, device=Device.GPU)
for _ in range(num_gpu_blocks)
]
assert allocator.get_num_free_blocks(Device.CPU) == 0
@ -47,7 +47,7 @@ def test_allocate_mutable(num_cpu_blocks: int, num_gpu_blocks: int,
@pytest.mark.parametrize("num_gpu_blocks", [1024])
@pytest.mark.parametrize("block_size", [2])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_allocate_immutable(num_cpu_blocks: int, num_gpu_blocks: int,
def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int,
block_size: int, allocator_type: str):
allocator = CpuGpuBlockAllocator.create(
allocator_type=allocator_type,
@ -67,7 +67,7 @@ def test_allocate_immutable(num_cpu_blocks: int, num_gpu_blocks: int,
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
cpu_blocks = [
allocator.allocate_immutable(prev_block=None,
allocator.allocate_immutable_block(prev_block=None,
token_ids=token_ids,
device=Device.CPU)
for token_ids in cpu_token_ids
@ -76,7 +76,7 @@ def test_allocate_immutable(num_cpu_blocks: int, num_gpu_blocks: int,
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
gpu_blocks = [
allocator.allocate_immutable(prev_block=None,
allocator.allocate_immutable_block(prev_block=None,
token_ids=token_ids,
device=Device.GPU)
for token_ids in gpu_token_ids

View File

@ -14,11 +14,11 @@ class TestNaiveBlockAllocator:
prev_block: Optional[Block],
token_ids: List[int]):
if allocate_type == "immutable":
allocate_block = lambda: allocator.allocate_immutable(
allocate_block = lambda: allocator.allocate_immutable_block(
prev_block=prev_block, token_ids=token_ids)
elif allocate_type == "mutable":
allocate_block = lambda: allocator.allocate_mutable(prev_block=
prev_block)
allocate_block = lambda: allocator.allocate_mutable_block(
prev_block=prev_block)
else:
raise ValueError()

View File

@ -26,11 +26,10 @@ class TestPrefixCachingBlock:
token_ids = list(range(num_to_fill))
mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator)
block_with_prev = PrefixCachingBlock(
prev_block=None,
block_with_prev = PrefixCachingBlock(prev_block=None,
token_ids=token_ids,
block_size=block_size,
prefix_caching_allocator=mock_allocator)
allocator=mock_allocator)
if is_curr_block_full:
# Expect hash since block is full.
@ -71,7 +70,7 @@ class TestPrefixCachingBlock:
prev_block=previous_block,
token_ids=token_ids,
block_size=block_size,
prefix_caching_allocator=mock_allocator,
allocator=mock_allocator,
)
if is_curr_block_full and prev_block_has_hash:
@ -138,7 +137,7 @@ class TestPrefixCachingBlock:
prev_block=prev_block,
token_ids=[],
block_size=block_size,
prefix_caching_allocator=allocator,
allocator=allocator,
)
tokens_to_append = token_ids[block_number *
@ -159,11 +158,11 @@ class TestPrefixCachingBlockAllocator:
prev_block: Optional[Block],
token_ids: List[int]):
if allocate_type == "immutable":
allocate_block = lambda: allocator.allocate_immutable(
allocate_block = lambda: allocator.allocate_immutable_block(
prev_block=prev_block, token_ids=token_ids)
elif allocate_type == "mutable":
allocate_block = lambda: allocator.allocate_mutable(prev_block=
prev_block)
allocate_block = lambda: allocator.allocate_mutable_block(
prev_block=prev_block)
else:
raise ValueError()
@ -233,12 +232,13 @@ class TestPrefixCachingBlockAllocator:
# Expect allocation with unseen hash to fail.
with pytest.raises(BlockAllocator.NoFreeBlocksError):
allocator.allocate_immutable(prev_block=chain[-1],
token_ids=list(range(block_size)))
allocator.allocate_immutable_block(prev_block=chain[-1],
token_ids=list(
range(block_size)))
# Expect mutable allocation to fail.
with pytest.raises(BlockAllocator.NoFreeBlocksError):
allocator.allocate_mutable(prev_block=chain[-1])
allocator.allocate_mutable_block(prev_block=chain[-1])
# Expect allocation of exact same chain to pass.
second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
@ -270,7 +270,7 @@ class TestPrefixCachingBlockAllocator:
# Expect mutable allocation to fail.
with pytest.raises(BlockAllocator.NoFreeBlocksError):
allocator.allocate_mutable(prev_block=None)
allocator.allocate_mutable_block(prev_block=None)
block_to_free = chain[-1]
@ -280,11 +280,11 @@ class TestPrefixCachingBlockAllocator:
allocator.free(block_to_free)
assert block_to_free.block_id is None, i
new_block = allocator.allocate_mutable(prev_block=None)
new_block = allocator.allocate_mutable_block(prev_block=None)
assert new_block.block_id == block_id, i
with pytest.raises(BlockAllocator.NoFreeBlocksError):
allocator.allocate_mutable(prev_block=None)
allocator.allocate_mutable_block(prev_block=None)
block_to_free = new_block
@ -376,7 +376,6 @@ class TestPrefixCachingBlockAllocator:
# Create token ids that will exhaust all blocks.
token_ids = list(range(num_blocks_to_consume * block_size))
blocks = list(range(num_blocks_to_consume))
first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
@ -384,9 +383,6 @@ class TestPrefixCachingBlockAllocator:
allocator=allocator,
)
# mark all blocks in first chain as computed
allocator.mark_blocks_as_computed(blocks)
# After zero_point, second_chain's token_ids would be set -1, which
# make it different from here comparing with first_chain
zero_point = random.randint(1, len(token_ids) - 1)
@ -424,15 +420,16 @@ class TestPrefixCachingBlockAllocator:
block_size=block_size)
token_ids = list(range(block_size))
block = allocator.allocate_immutable(prev_block=None,
block = allocator.allocate_immutable_block(prev_block=None,
token_ids=token_ids)
assert allocator._refcounter.get(block.block_id) == 1
m = allocator.allocate_mutable(prev_block=None)
m = allocator.allocate_mutable_block(prev_block=None)
block_id = m.block_id
for i in range(block_size):
m.append_token_ids([i])
# After block get promoted to immutable from mutable, if there is
# already same content hash block, then it shall be released into
# hashless_allocator
@ -452,48 +449,79 @@ class TestPrefixCachingBlockAllocator:
all_blocks_list = [i for i in range(num_blocks)]
zero_ref = {i: 0 for i in range(num_blocks)}
one_ref = {i: 1 for i in range(num_blocks)}
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
block_size=block_size)
token_ids = list(range(num_blocks * block_size))
# now we have num_blocks free blocks in hashless allocator
# with internal tracking list _blocks _cached_blocks and evictor
# empty and block's ref shall be 0
# Verify initial/pre-alloc state
# Ensure all blocks are free inside hashless allocator
assert list(allocator._hashless_allocator._free_block_indices
) == all_blocks_list
assert len(allocator._blocks.keys()) == 0
# Ensure no tracked blocks
assert len(allocator._block_tracker.keys()) == num_blocks
for block_id in range(num_blocks):
assert not allocator._block_tracker[block_id].active
# Ensure no cached blocks
assert len(allocator._cached_blocks.values()) == 0
# Ensure no evicted blocks
assert len(allocator.evictor.free_table.keys()) == 0
# Ensure 0s ref counts for all blocks
assert allocator._refcounter._refcounts == zero_ref
# Allocate immutable chains with only one block residuled in
new_block = []
for i in range(num_blocks):
block = allocator.allocate_immutable(
block = allocator.allocate_immutable_block(
prev_block=None,
token_ids=token_ids[block_size * i:block_size * (i + 1)])
new_block.append(block)
# Verify post-alloc state
# Ensure no blocks are free inside hashless allocator
assert (len(allocator._hashless_allocator._free_block_indices) == 0)
# Ensure all blocks are tracked
assert len(allocator._block_tracker.keys()) == num_blocks
for block_id in range(num_blocks):
assert allocator._block_tracker[block_id].active
# Ensure all blocks are cached (all promoted)
assert len(allocator._cached_blocks.values()) == num_blocks
# Ensure no evicted blocks
assert len(allocator.evictor.free_table.keys()) == 0
# Ensure 1s ref counts for all blocks
assert allocator._refcounter._refcounts == one_ref
# Free all blocks, and now all blocks shall be in the evictor
# there shall be no tracking data left in _blocks
# there shall be no tracking data left in _block_tracker
# all blocks shall be tracked in _cached_blocks
# all blocks' ref shall be zero
for block in new_block:
allocator.free(block)
assert len(allocator._blocks.keys()) == 0
# Verify post-free state
# Ensure no tracked blocks
assert len(allocator._block_tracker.keys()) == num_blocks
for block_id in range(num_blocks):
assert not allocator._block_tracker[block_id].active
# Ensure no blocks in hashless allocator (all promoted)
assert len(allocator._hashless_allocator._free_block_indices) == 0
# Ensure all blocks are cached
assert list(allocator._cached_blocks.values()) == all_blocks_list
# Ensure all blocks are inside the evictor
assert list(allocator.evictor.free_table.keys()) == all_blocks_list
# Ensure 0s refcounts
assert allocator._refcounter._refcounts == zero_ref
# Allocate a mutable block, and the first block shall be evicted
# and set its content hash into None, ref to 1
mutable = allocator.allocate_mutable(prev_block=None)
mutable = allocator.allocate_mutable_block(prev_block=None)
assert mutable.block_id == 0
assert mutable.content_hash is None
assert 0 in allocator._blocks
assert allocator._block_tracker[0].active
assert allocator._refcounter.get(0) == 1
assert 0 not in allocator._cached_blocks
assert 0 not in allocator.evictor
@ -502,27 +530,27 @@ class TestPrefixCachingBlockAllocator:
# hashless allocator
allocator.free(mutable)
assert len(allocator._blocks.keys()) == 0
assert not allocator._block_tracker[0].active
assert allocator._refcounter._refcounts == zero_ref
assert 0 not in allocator._cached_blocks
assert 0 not in allocator.evictor
assert 0 in allocator._hashless_allocator._free_block_indices
# when allocate immutable with first block_size tokens, we
# When allocate immutable with first block_size tokens, we
# shall get free block from hashless allocator, thus no block left
# in hashless
block = allocator.allocate_immutable(prev_block=None,
token_ids=token_ids[:block_size])
block = allocator.allocate_immutable_block(
prev_block=None, token_ids=token_ids[:block_size])
assert block.block_id == 0
assert len(allocator._hashless_allocator._free_block_indices) == 0
assert 0 in allocator._blocks
assert allocator._block_tracker[0].active
assert 0 in allocator._cached_blocks.values()
assert allocator._refcounter.get(0) == 1
assert 0 not in allocator.evictor
# allocate mutable block again, it shall be popped from evictor
mutable = allocator.allocate_mutable(prev_block=None)
mutable = allocator.allocate_mutable_block(prev_block=None)
assert len(allocator._hashless_allocator._free_block_indices) == 0
assert mutable.block_id not in allocator.evictor.free_table
assert allocator._refcounter.get(mutable.block_id) == 1
@ -619,7 +647,7 @@ class TestPrefixCachingBlockAllocator:
block_token_ids = token_ids[block_number *
block_size:(block_number + 1) *
block_size]
prev_block = allocator.allocate_immutable(
prev_block = allocator.allocate_immutable_block(
prev_block=prev_block, token_ids=block_token_ids)
blocks.append(prev_block)

View File

@ -90,10 +90,10 @@ def test_create_single_target_seq_group_metadata(k: int):
assert output.request_id == input_seq_group_metadata.request_id
assert len(output.seq_data) == 1
assert output.seq_data[target_seq_id].get_prompt_token_ids(
) == prompt_tokens
assert output.seq_data[target_seq_id].get_output_token_ids(
) == prev_output_tokens + token_ids
assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple(
prompt_tokens)
assert output.seq_data[target_seq_id].get_output_token_ids() == tuple(
prev_output_tokens + token_ids)
assert len(output.block_tables) == 1
assert output.block_tables[

View File

@ -1,5 +1,6 @@
from typing import List, Optional
from vllm.core.block.common import BlockList
from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator
from vllm.utils import Device, cdiv, chunk_list
@ -47,12 +48,10 @@ class BlockTable:
self._allocator = block_allocator
if _blocks is None:
_blocks = []
self._blocks: List[Block] = _blocks
self._blocks: BlockList = BlockList(_blocks)
self._max_block_sliding_window = max_block_sliding_window
# Use helper method instead of directly calculating, as blocks
# may not be allocated.
self._num_full_slots = len(self._get_all_token_ids())
self._num_full_slots = self._get_num_token_ids()
@staticmethod
def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
@ -88,11 +87,18 @@ class BlockTable:
"""
assert not self._is_allocated
assert token_ids
self._blocks = self._allocate_blocks_for_token_ids(prev_block=None,
blocks = self._allocate_blocks_for_token_ids(prev_block=None,
token_ids=token_ids,
device=device)
self.update(blocks)
self._num_full_slots = len(token_ids)
def update(self, blocks: List[Block]) -> None:
"""Resets the table to the newly provided blocks
(with their corresponding block ids)
"""
self._blocks.update(blocks)
def append_token_ids(self,
token_ids: List[int],
num_lookahead_slots: int = 0,
@ -140,11 +146,11 @@ class BlockTable:
num_lookahead_slots)
# Update the blocks with the new tokens
blocks = self._blocks[self._num_full_slots // self._block_size:]
first_block_idx = self._num_full_slots // self._block_size
token_blocks = self._chunk_token_blocks_for_append(token_ids)
for block, token_block in zip(blocks, token_blocks):
block.append_token_ids(token_block)
for i, token_block in enumerate(token_blocks):
self._blocks.append_token_ids(first_block_idx + i, token_block)
self._num_full_slots += len(token_ids)
@ -174,8 +180,8 @@ class BlockTable:
for _ in range(blocks_to_allocate):
assert len(self._blocks) > 0
self._blocks.append(
self._allocator.allocate_mutable(prev_block=self._blocks[-1],
device=device))
self._allocator.allocate_mutable_block(
prev_block=self._blocks[-1], device=device))
def fork(self) -> "BlockTable":
"""Creates a new BlockTable instance with a copy of the blocks from the
@ -209,12 +215,12 @@ class BlockTable:
is set to `None`.
"""
assert self._is_allocated
for block in self._blocks:
for block in self.blocks:
self._allocator.free(block)
self._blocks = []
self._blocks.reset()
@property
def physical_block_ids(self) -> List[Optional[int]]:
def physical_block_ids(self) -> List[int]:
"""Returns a list of physical block indices for the blocks in the
BlockTable.
@ -228,7 +234,7 @@ class BlockTable:
BlockTable.
"""
assert self._is_allocated
return [block.block_id for block in self._blocks]
return self._blocks.ids()
def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
"""Get the number of "unseen" tokens in the sequence.
@ -253,17 +259,31 @@ class BlockTable:
token_ids: List[int],
device: Device) -> List[Block]:
blocks: List[Block] = []
for block_token_ids in chunk_list(token_ids, self._block_size):
if len(block_token_ids) == self._block_size:
# If the block is full, create an immutable block.
prev_block = self._allocator.allocate_immutable(
prev_block, token_ids=block_token_ids, device=device)
block_token_ids = []
tail_token_ids = []
for cur_token_ids in chunk_list(token_ids, self._block_size):
if len(cur_token_ids) == self._block_size:
block_token_ids.append(cur_token_ids)
else:
# Else, partially fill a mutable block with token ids.
prev_block = self._allocator.allocate_mutable(
tail_token_ids.append(cur_token_ids)
if block_token_ids:
blocks.extend(
self._allocator.allocate_immutable_blocks(
prev_block, block_token_ids=block_token_ids,
device=device))
prev_block = blocks[-1]
if tail_token_ids:
assert len(tail_token_ids) == 1
cur_token_ids = tail_token_ids[0]
block = self._allocator.allocate_mutable_block(
prev_block=prev_block, device=device)
prev_block.append_token_ids(block_token_ids)
blocks.append(prev_block)
block.append_token_ids(cur_token_ids)
blocks.append(block)
return blocks
@ -274,18 +294,25 @@ class BlockTable:
if not self._is_allocated:
return token_ids
for block in self._blocks:
for block in self.blocks:
token_ids.extend(block.token_ids)
return token_ids
def _get_num_token_ids(self) -> int:
res = 0
for block in self.blocks:
res += len(block.token_ids)
return res
@property
def _is_allocated(self) -> bool:
return len(self._blocks) > 0
@property
def blocks(self) -> Optional[List[Block]]:
return self._blocks
def blocks(self) -> List[Block]:
return self._blocks.list()
@property
def _num_empty_slots(self) -> int:

View File

@ -1,4 +1,5 @@
from typing import Dict, Iterable, List, Optional, Protocol, Tuple
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
from vllm.core.block.interfaces import Block, BlockAllocator
@ -95,64 +96,40 @@ class CopyOnWriteTracker:
The CopyOnWriteTracker class maintains a mapping of source block indices to
their corresponding copy-on-write destination block indices. It works in
conjunction with a RefCounter and a BlockAllocator to handle reference
counting and block allocation.
conjunction with a RefCounter.
Args:
refcounter (RefCounter): The reference counter used to track block
reference counts.
allocator (BlockAllocator): The block allocator used to allocate and
free blocks.
"""
def __init__(
self,
refcounter: RefCounterProtocol,
allocator: BlockAllocator,
):
def __init__(self, refcounter: RefCounterProtocol):
self._copy_on_writes: List[Tuple[BlockId, BlockId]] = []
self._refcounter = refcounter
self._allocator = allocator
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
This method checks the reference count of the given block. If the
reference count is greater than 1, indicating that the block is shared,
a copy-on-write operation is performed. The original block is freed,
and a new block is allocated with the same content. The new block index
is returned.
Args:
block (Block): The block to check for copy-on-write.
Returns:
Optional[BlockId]: The block index of the new block if a copy-on
-write operation was performed, or the original block index if
no copy-on-write was necessary.
def is_appendable(self, block: Block) -> bool:
"""Checks if the block is shared or not. If shared, then it cannot
be appended and needs to be duplicated via copy-on-write
"""
block_id = block.block_id
if block_id is None:
return block_id
return True
refcount = self._refcounter.get(block_id)
assert refcount != 0
if refcount > 1:
src_block_id = block_id
# Decrement refcount of the old block.
self._allocator.free(block)
return refcount <= 1
# Allocate a fresh new block.
block_id = self._allocator.allocate_mutable(
prev_block=block.prev_block).block_id
# Track src/dst copy.
def record_cow(self, src_block_id: Optional[BlockId],
trg_block_id: Optional[BlockId]) -> None:
"""Records a copy-on-write operation from source to target block id
Args:
src_block_id (BlockId): The source block id from which to copy
the data
trg_block_id (BlockId): The target block id to which the data
is copied
"""
assert src_block_id is not None
assert block_id is not None
self._copy_on_writes.append((src_block_id, block_id))
return block_id
assert trg_block_id is not None
self._copy_on_writes.append((src_block_id, trg_block_id))
def clear_cows(self) -> List[Tuple[BlockId, BlockId]]:
"""Clears the copy-on-write tracking information and returns the current
@ -172,6 +149,139 @@ class CopyOnWriteTracker:
return cows
class BlockPool:
"""Used to pre-allocate block objects, in order to avoid excessive python
object allocations/deallocations.
The pool starts from "pool_size" objects and will increase to more objects
if necessary
Note that multiple block objects may point to the same physical block id,
which is why this pool is needed, so that it will be easier to support
prefix caching and more complicated sharing of physical blocks.
"""
def __init__(self, block_size: int, create_block: Block.Factory,
allocator: BlockAllocator, pool_size: int):
self._block_size = block_size
self._create_block = create_block
self._allocator = allocator
self._pool_size = pool_size
assert self._pool_size >= 0
self._free_ids: Deque[int] = deque(range(self._pool_size))
self._pool = []
for i in range(self._pool_size):
self._pool.append(
self._create_block(prev_block=None,
token_ids=[],
block_size=self._block_size,
allocator=self._allocator,
block_id=None))
def increase_pool(self):
"""Doubles the internal pool size
"""
cur_pool_size = self._pool_size
new_pool_size = cur_pool_size * 2
self._pool_size = new_pool_size
self._free_ids += deque(range(cur_pool_size, new_pool_size))
for i in range(cur_pool_size, new_pool_size):
self._pool.append(
self._create_block(prev_block=None,
token_ids=[],
block_size=self._block_size,
allocator=self._allocator,
block_id=None))
def init_block(self, prev_block: Optional[Block], token_ids: List[int],
block_size: int, physical_block_id: Optional[int]) -> Block:
if len(self._free_ids) == 0:
self.increase_pool()
assert len(self._free_ids) > 0
pool_id = self._free_ids.popleft()
block = self._pool[pool_id]
block.__init__( # type: ignore[misc]
prev_block=prev_block,
token_ids=token_ids,
block_size=block_size,
allocator=block._allocator, # type: ignore[attr-defined]
block_id=physical_block_id)
block.pool_id = pool_id # type: ignore[attr-defined]
return block
def free_block(self, block: Block) -> None:
self._free_ids.appendleft(block.pool_id) # type: ignore[attr-defined]
class BlockList:
"""This class is an optimization to allow fast-access to physical
block ids. It maintains a block id list that is updated with the
block list and this avoids the need to reconstruct the block id
list on every iteration of the block manager
"""
def __init__(self, blocks: List[Block]):
self._blocks: List[Block] = []
self._block_ids: List[int] = []
self.update(blocks)
def _add_block_id(self, block_id: Optional[BlockId]) -> None:
assert block_id is not None
self._block_ids.append(block_id)
def _update_block_id(self, block_index: int,
new_block_id: Optional[BlockId]) -> None:
assert new_block_id is not None
self._block_ids[block_index] = new_block_id
def update(self, blocks: List[Block]):
self._blocks = blocks
# Cache block ids for fast query
self._block_ids = []
for block in self._blocks:
self._add_block_id(block.block_id)
def append_token_ids(self, block_index: int, token_ids: List[int]) -> None:
block = self._blocks[block_index]
prev_block_id = block.block_id
block.append_token_ids(token_ids)
# CoW or promotion may update the internal block_id
if prev_block_id != block.block_id:
self._update_block_id(block_index, block.block_id)
def append(self, new_block: Block):
self._blocks.append(new_block)
self._add_block_id(new_block.block_id)
def __len__(self) -> int:
return len(self._blocks)
def __getitem__(self, block_index: int) -> Block:
return self._blocks[block_index]
def __setitem__(self, block_index: int, new_block: Block) -> None:
self._blocks[block_index] = new_block
self._update_block_id(block_index, new_block.block_id)
def reset(self):
self._blocks = []
self._block_ids = []
def list(self) -> List[Block]:
return self._blocks
def ids(self) -> List[int]:
return self._block_ids
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
"""Retrieves all the blocks in a sequence starting from the last block.

View File

@ -113,10 +113,10 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
def allocate_or_get_null_block(self) -> Block:
if self._null_block is None:
self._null_block = NullBlock(
self.allocate_mutable(None, Device.GPU))
self.allocate_mutable_block(None, Device.GPU))
return self._null_block
def allocate_mutable(self, prev_block: Optional[Block],
def allocate_mutable_block(self, prev_block: Optional[Block],
device: Device) -> Block:
"""Allocates a new mutable block on the specified device.
@ -128,10 +128,31 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Returns:
Block: The newly allocated mutable block.
"""
return self._allocators[device].allocate_mutable(prev_block)
return self._allocators[device].allocate_mutable_block(prev_block)
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
def allocate_immutable_blocks(self, prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Optional[Device]) -> List[Block]:
"""Allocates a new group of immutable blocks with the provided block
token IDs on the specified device.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
Used for prefix hashing.
block_token_ids (List[int]): The list of block token IDs to be
stored in the new blocks.
device (Device): The device on which to allocate the new block.
Returns:
List[Block]: The newly allocated list of immutable blocks
containing the provided block token IDs.
"""
return self._allocators[device].allocate_immutable_blocks(
prev_block, block_token_ids)
def allocate_immutable_block(self, prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> Block:
"""Allocates a new immutable block with the provided token IDs on the
specified device.
@ -146,7 +167,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Block: The newly allocated immutable block containing the provided
token IDs.
"""
return self._allocators[device].allocate_immutable(
return self._allocators[device].allocate_immutable_block(
prev_block, token_ids)
def free(self, block: Block) -> None:
@ -161,7 +182,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
block_id = block.block_id
assert block_id is not None
allocator = self._block_ids_to_allocator[block_id]
return allocator.free(block)
allocator.free(block)
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
@ -210,8 +231,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
"""
return self._allocators[device].get_physical_block_id(absolute_id)
def swap(self, blocks: List[Block], source_device: Device,
dest_device: Device) -> Dict[int, int]:
def swap(self, blocks: List[Block], src_device: Device,
dst_device: Device) -> Dict[int, int]:
"""Execute the swap for the given blocks from source_device
on to dest_device, save the current swap mapping and append
them to the accumulated `self._swap_mapping` for each
@ -219,23 +240,23 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args:
blocks: List of blocks to be swapped.
source_device (Device): Device to swap the 'blocks' from.
dest_device (Device): Device to swap the 'blocks' to.
src_device (Device): Device to swap the 'blocks' from.
dst_device (Device): Device to swap the 'blocks' to.
Returns:
Dict[int, int]: Swap mapping from source_device
on to dest_device.
"""
source_block_ids = [block.block_id for block in blocks]
self._allocators[source_device].swap_out(blocks)
self._allocators[dest_device].swap_in(blocks)
dest_block_ids = [block.block_id for block in blocks]
src_block_ids = [block.block_id for block in blocks]
self._allocators[src_device].swap_out(blocks)
self._allocators[dst_device].swap_in(blocks)
dst_block_ids = [block.block_id for block in blocks]
current_swap_mapping: Dict[int, int] = {}
for src, dest in zip(source_block_ids, dest_block_ids):
if src is not None and dest is not None:
self._swap_mapping[src] = dest
current_swap_mapping[src] = dest
for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids):
if src_block_id is not None and dst_block_id is not None:
self._swap_mapping[src_block_id] = dst_block_id
current_swap_mapping[src_block_id] = dst_block_id
return current_swap_mapping
def get_num_blocks_touched(self,
@ -283,23 +304,25 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
device = Device.GPU
return self._allocators[device].mark_blocks_as_computed(block_ids)
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].get_computed_block_ids(
prev_computed_block_ids, block_ids, skip_last_block_id)
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].get_common_computed_block_ids(
seq_block_ids)
computed_seq_block_ids)
@property
def all_block_ids(self) -> FrozenSet[int]:
return frozenset(self._block_ids_to_allocator.keys())
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
raise NotImplementedError
def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
"""Returns and clears the mapping of source to destination block IDs.
Will be called after every swapping operations for now, and after every
@ -341,6 +364,11 @@ class NullBlock(Block):
def token_ids(self) -> List[BlockId]:
return self._proxy.token_ids
@property
def num_tokens_total(self) -> int:
raise NotImplementedError(
"num_tokens_total is not used for null block")
@property
def num_empty_slots(self) -> BlockId:
return self._proxy.num_empty_slots

View File

@ -28,6 +28,13 @@ class Block(ABC):
def token_ids(self) -> List[int]:
pass
@property
@abstractmethod
def num_tokens_total(self) -> int:
"""The number of tokens till the current block (inclusive)
"""
pass
@property
@abstractmethod
def num_empty_slots(self) -> int:
@ -92,14 +99,20 @@ class Block(ABC):
class BlockAllocator(ABC):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
def allocate_mutable_block(self, prev_block: Optional[Block]) -> Block:
pass
@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
def allocate_immutable_block(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
pass
@abstractmethod
def allocate_immutable_blocks(
self, prev_block: Optional[Block],
block_token_ids: List[List[int]]) -> List[Block]:
pass
@abstractmethod
def free(self, block: Block) -> None:
pass
@ -147,12 +160,18 @@ class BlockAllocator(ABC):
pass
@abstractmethod
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
pass
@abstractmethod
def cow_block_if_not_appendable(self, block: Block) -> Optional["BlockId"]:
def get_common_computed_block_ids(
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
pass
@abstractmethod
def cow_block_if_not_appendable(self, block: Block) -> BlockId:
"""NOTE: This should not be used besides Block"""
pass
@ -174,13 +193,20 @@ class BlockAllocator(ABC):
class DeviceAwareBlockAllocator(ABC):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block],
def allocate_mutable_block(self, prev_block: Optional[Block],
device: Device) -> Block:
pass
@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
def allocate_immutable_block(self, prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> Block:
pass
@abstractmethod
def allocate_immutable_blocks(self, prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Device) -> List[Block]:
pass
@abstractmethod
@ -217,9 +243,15 @@ class DeviceAwareBlockAllocator(ABC):
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass
@abstractmethod
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
pass
@abstractmethod
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
pass
@abstractmethod
@ -230,8 +262,8 @@ class DeviceAwareBlockAllocator(ABC):
pass
@abstractmethod
def swap(self, blocks: List[Block], source_device: Device,
dest_device: Device) -> Dict[int, int]:
def swap(self, blocks: List[Block], src_device: Device,
dst_device: Device) -> Dict[int, int]:
pass
@abstractmethod

View File

@ -1,6 +1,7 @@
from typing import FrozenSet, Iterable, List, Optional, Set, Tuple
from collections import deque
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple
from vllm.core.block.common import (CopyOnWriteTracker, RefCounter,
from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.utils import cdiv
@ -31,25 +32,36 @@ class NaiveBlockAllocator(BlockAllocator):
num_blocks: int,
block_size: int,
block_ids: Optional[Iterable[int]] = None,
block_pool: Optional[BlockPool] = None,
):
if block_ids is None:
block_ids = range(num_blocks)
self._free_block_indices: Set[BlockId] = set(block_ids)
self._free_block_indices: Deque[BlockId] = deque(block_ids)
self._all_block_indices = frozenset(block_ids)
assert len(self._all_block_indices) == num_blocks
self._refcounter = RefCounter(
all_block_indices=self._free_block_indices)
self._create_block = create_block
self._block_size = block_size
self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly(),
allocator=self,
)
refcounter=self._refcounter.as_readonly())
def allocate_immutable(self,
if block_pool is None:
extra_factor = 4
# Pre-allocate "num_blocks * extra_factor" block objects.
# The "* extra_factor" is a buffer to allow more block objects
# than physical blocks
self._block_pool = BlockPool(self._block_size, create_block, self,
num_blocks * extra_factor)
else:
# In this case, the block pool is provided by the caller,
# which means that there is most likely a need to share
# a block pool between allocators
self._block_pool = block_pool
def allocate_immutable_block(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
@ -66,11 +78,34 @@ class NaiveBlockAllocator(BlockAllocator):
Block: The newly allocated immutable block.
"""
assert device is None
block = self.allocate_mutable(prev_block=prev_block)
block = self.allocate_mutable_block(prev_block=prev_block)
block.append_token_ids(token_ids)
return block
def allocate_mutable(self,
def allocate_immutable_blocks(
self,
prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Optional[Device] = None) -> List[Block]:
assert device is None
num_blocks = len(block_token_ids)
block_ids = []
for i in range(num_blocks):
block_ids.append(self._allocate_block_id())
blocks = []
for i in range(num_blocks):
prev_block = self._block_pool.init_block(
prev_block=prev_block,
token_ids=block_token_ids[i],
block_size=self._block_size,
physical_block_id=block_ids[i])
blocks.append(prev_block)
return blocks
def allocate_mutable_block(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a new mutable block, linked to the previous block.
@ -84,20 +119,39 @@ class NaiveBlockAllocator(BlockAllocator):
Block: The newly allocated mutable block.
"""
assert device is None
block_id = self._allocate_new_block_id()
return self._create_block(
prev_block=prev_block,
block_id = self._allocate_block_id()
block = self._block_pool.init_block(prev_block=prev_block,
token_ids=[],
block_id=block_id,
block_size=self._block_size,
allocator=self,
)
physical_block_id=block_id)
return block
def _allocate_block_id(self) -> BlockId:
if not self._free_block_indices:
raise BlockAllocator.NoFreeBlocksError()
block_id = self._free_block_indices.popleft()
self._refcounter.incr(block_id)
return block_id
def _free_block_id(self, block: Block) -> None:
block_id = block.block_id
assert block_id is not None
refcount = self._refcounter.decr(block_id)
if refcount == 0:
self._free_block_indices.appendleft(block_id)
def free(self, block: Block) -> None:
assert block.block_id is not None
self._free_block_id(block.block_id)
block.block_id = None
def free(self, block: Block, keep_block_object: bool = False) -> None:
# Release the physical block id
self._free_block_id(block)
# Release the block object
if not keep_block_object:
self._block_pool.free_block(block)
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
memory as the original sequence.
@ -120,14 +174,13 @@ class NaiveBlockAllocator(BlockAllocator):
refcount = self._refcounter.incr(block.block_id)
assert refcount != 1, "can't fork free'd block"
forked_blocks.append(
self._create_block(
forked_block = self._block_pool.init_block(
prev_block=prev_block,
token_ids=block.token_ids,
block_id=block.block_id,
block_size=self._block_size,
allocator=self,
))
physical_block_id=block.block_id)
forked_blocks.append(forked_block)
prev_block = forked_blocks[-1]
return forked_blocks
@ -138,20 +191,6 @@ class NaiveBlockAllocator(BlockAllocator):
def get_num_total_blocks(self) -> int:
return len(self._all_block_indices)
def _allocate_new_block_id(self) -> BlockId:
if not self._free_block_indices:
raise BlockAllocator.NoFreeBlocksError()
block_id = next(iter(self._free_block_indices))
self._refcounter.incr(block_id)
self._free_block_indices.remove(block_id)
return block_id
def _free_block_id(self, block_id: BlockId) -> None:
refcount = self._refcounter.decr(block_id)
if refcount == 0:
self._free_block_indices.add(block_id)
def get_physical_block_id(self, absolute_id: int) -> int:
"""Returns the zero-offset block id on certain block allocator
given the absolute block id.
@ -173,7 +212,7 @@ class NaiveBlockAllocator(BlockAllocator):
def all_block_ids(self) -> FrozenSet[int]:
return self._all_block_indices
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
def cow_block_if_not_appendable(self, block: Block) -> BlockId:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
@ -181,11 +220,22 @@ class NaiveBlockAllocator(BlockAllocator):
block (Block): The block to check for copy-on-write.
Returns:
Optional[BlockId]: The block index of the new block if a copy-on
-write operation was performed, or the original block index if
BlockId: The block index of the new block if a copy-on-write
operation was performed, or the original block index if
no copy-on-write was necessary.
"""
return self._cow_tracker.cow_block_if_not_appendable(block)
src_block_id = block.block_id
assert src_block_id is not None
if self._cow_tracker.is_appendable(block):
return src_block_id
self._free_block_id(block)
trg_block_id = self._allocate_block_id()
self._cow_tracker.record_cow(src_block_id, trg_block_id)
return trg_block_id
def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
"""Returns the copy-on-write source->destination mapping and clears it.
@ -213,8 +263,15 @@ class NaiveBlockAllocator(BlockAllocator):
"""
pass
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
"""No prefix caching here => return empty list
"""
return []
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
"""Determine blocks that can be skipped in prefill.
Since the naive allocator does not support prefix caching, always return
@ -223,7 +280,7 @@ class NaiveBlockAllocator(BlockAllocator):
return []
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
raise NotImplementedError("There is no promotion for naive blocks")
def get_num_blocks_touched(self,
blocks: List[Block],
@ -263,17 +320,27 @@ class NaiveBlockAllocator(BlockAllocator):
def swap_out(self, blocks: List[Block]) -> None:
for block in blocks:
self.free(block)
self._free_block_id(block)
def swap_in(self, blocks: List[Block]) -> None:
for block in blocks:
# Here we allocate either immutable or mutable block and then
# extract its block_id. Note that the block object is released
# and the block_id is assigned to "block" to allow reusing the
# existing "block" object
if block.is_full:
alloc = self.allocate_immutable(block.prev_block,
block.token_ids)
tmp_block = self.allocate_immutable_block(
prev_block=block.prev_block, token_ids=block.token_ids)
else:
alloc = self.allocate_mutable(block.prev_block)
alloc.append_token_ids(block.token_ids)
block.block_id = alloc.block_id
tmp_block = self.allocate_mutable_block(
prev_block=block.prev_block)
tmp_block.append_token_ids(block.token_ids)
block_id = tmp_block.block_id
tmp_block.block_id = None
self._block_pool.free_block(tmp_block)
block.block_id = block_id # Assign block_id
class NaiveBlock(Block):
@ -315,11 +382,12 @@ class NaiveBlock(Block):
self._append_token_ids_no_cow(token_ids)
def append_token_ids(self, token_ids: List[int]) -> None:
"""Appends the given token IDs to the block, instructing the allocator
to perform a copy-on-write if necessary.
"""Appends the given token IDs to the block and performs a
copy-on-write if necessary.
Args:
token_ids (List[int]): The token IDs to be appended to the block.
token_ids (Optional[List[int]]): The token IDs to be appended
to the block.
"""
self._append_token_ids_no_cow(token_ids)
@ -328,7 +396,16 @@ class NaiveBlock(Block):
self._cow_target))
def _append_token_ids_no_cow(self, token_ids: List[int]) -> None:
assert self.num_empty_slots >= len(token_ids)
"""Appends the given token IDs to the block
Args:
token_ids (List[int]): The token IDs to be appended to the block.
"""
if len(token_ids) == 0:
return
assert len(token_ids) <= self.num_empty_slots
self._token_ids.extend(token_ids)
@property
@ -361,12 +438,17 @@ class NaiveBlock(Block):
@property
def num_empty_slots(self) -> int:
return self._block_size - len(self._token_ids)
return self._block_size - len(self.token_ids)
@property
def token_ids(self) -> List[int]:
return self._token_ids
@property
def num_tokens_total(self) -> int:
raise NotImplementedError(
"num_tokens_total is not used for naive block")
@property
def block_size(self) -> int:
return self._block_size

View File

@ -1,13 +1,13 @@
"""Token blocks."""
from itertools import takewhile
from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
from vllm.core.block.common import (CopyOnWriteTracker,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
NaiveBlockAllocator)
from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor
from vllm.utils import cdiv
@ -19,6 +19,30 @@ PrefixHash = int
_DEFAULT_LAST_ACCESSED_TIME = -1
class BlockTracker:
"""Used to track the status of a block inside the prefix caching allocator
"""
__slots__ = ("active", "last_accessed", "computed")
def reset(self):
self.last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
self.computed: bool = False
def __init__(self):
self.active: bool = False
self.reset()
def enable(self):
assert not self.active
self.active = True
self.reset()
def disable(self):
assert self.active
self.active = False
self.reset()
class PrefixCachingBlockAllocator(BlockAllocator):
"""A block allocator that implements prefix caching.
@ -41,12 +65,26 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_ids: Optional[Iterable[int]] = None,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
):
if block_ids is None:
block_ids = range(num_blocks)
self._block_size = block_size
# A mapping of prefix hash to block index. All blocks which have a
# prefix hash will be in this dict, even if they have refcount 0.
self._cached_blocks: Dict[PrefixHash, BlockId] = {}
# A mapping of blockId to Block to track those cached blocks
self._blocks: Dict[BlockId, Block] = {}
# Used to track status of each physical block id
self._block_tracker: Dict[BlockId, BlockTracker] = {}
for block_id in block_ids:
self._block_tracker[block_id] = BlockTracker()
# Pre-allocate "num_blocks * extra_factor" block objects.
# The "* extra_factor" is a buffer to allow more block objects
# than physical blocks
extra_factor = 4
self._block_pool = BlockPool(self._block_size, self._create_block,
self, num_blocks * extra_factor)
# An allocator for blocks that do not have prefix hashes.
self._hashless_allocator = NaiveBlockAllocator(
@ -54,10 +92,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
num_blocks=num_blocks,
block_size=block_size,
block_ids=block_ids,
block_pool=self._block_pool, # Share block pool here
)
self._block_size = block_size
# Evitor used to maintain how we want to handle those computed blocks
# if we find memory pressure is high.
self.evictor: Evictor = make_evictor(eviction_policy)
@ -68,9 +105,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self._refcounter = self._hashless_allocator.refcounter
self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly(),
allocator=self,
)
refcounter=self._refcounter.as_readonly())
# Implements Block.Factory.
def _create_block(
@ -90,11 +125,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
token_ids=token_ids,
block_size=block_size,
block_id=block_id,
prefix_caching_allocator=allocator,
allocator=allocator,
computed=computed,
)
def allocate_immutable(self,
def allocate_immutable_block(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
@ -111,27 +146,39 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert device is None
assert_prefix_caching_block_or_none(prev_block)
block = self._create_block(
prev_block=prev_block,
# First, try to create a block that points to cached data
block = self._block_pool.init_block(prev_block=prev_block,
token_ids=token_ids,
block_size=self._block_size,
allocator=self,
)
physical_block_id=None)
assert block.content_hash is not None
cached_block_id = self._cached_blocks.get(block.content_hash, None)
if cached_block_id is not None:
block.block_id = cached_block_id
self._incr_refcount_cached_block(block, block.block_id)
self._incr_refcount_cached_block(block)
return block
self._block_pool.free_block(block)
block = self.allocate_mutable(prev_block)
# No cached block => Allocate a new block
block = self.allocate_mutable_block(prev_block)
block.append_token_ids(token_ids)
assert block.content_hash is not None
return block
def allocate_mutable(self,
def allocate_immutable_blocks(
self,
prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Optional[Device] = None) -> List[Block]:
blocks = []
for token_ids in block_token_ids:
prev_block = self.allocate_immutable_block(prev_block=prev_block,
token_ids=token_ids,
device=device)
blocks.append(prev_block)
return blocks
def allocate_mutable_block(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a mutable block. If there are no free blocks, this will
@ -147,27 +194,112 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert device is None
assert_prefix_caching_block_or_none(prev_block)
try:
block = self._hashless_allocator.allocate_mutable(
prev_block=prev_block)
assert block.block_id not in self._blocks
assert block.block_id is not None
self._blocks[block.block_id] = block
block_id = self._allocate_block_id()
block = self._block_pool.init_block(prev_block=prev_block,
token_ids=[],
block_size=self._block_size,
physical_block_id=block_id)
assert not block.computed
assert block.content_hash is None
return block
except BlockAllocator.NoFreeBlocksError:
# We must check the unused cached blocks before raising OOM.
pass
# If the evictor has blocks available for eviction, evict a block
# and return it.
if self.evictor.num_blocks > 0:
# here we get an evicted block, which is only added
def _incr_refcount_cached_block(self, block: Block) -> None:
# Set this block to be "computed" since it is pointing to a
# cached block id (which was already computed)
block.computed = True
block_id = block.block_id
assert block_id is not None
refcount = self._refcounter.incr(block_id)
if refcount == 1:
# In case a cached block was evicted, restore its tracking
if block_id in self.evictor:
self.evictor.remove(block_id)
self._track_block_id(block_id, computed=True)
def _decr_refcount_cached_block(self, block: Block) -> None:
# Ensure this is immutable/cached block
assert block.content_hash is not None
block_id = block.block_id
assert block_id is not None
refcount = self._refcounter.decr(block_id)
if refcount > 0:
block.block_id = None
return
else:
assert refcount == 0
# No longer used
assert block.content_hash in self._cached_blocks
# Add the cached block to the evictor
# (This keeps the cached block around so it can be reused)
self.evictor.add(block_id, block.content_hash, block.num_tokens_total,
self._block_tracker[block_id].last_accessed)
# Stop tracking the block
self._untrack_block_id(block_id)
block.block_id = None
def _decr_refcount_hashless_block(self, block: Block) -> None:
block_id = block.block_id
assert block_id is not None
# We may have a fork case where block is shared,
# in which case, we cannot remove it from tracking
refcount = self._refcounter.get(block_id)
if refcount == 1:
self._untrack_block_id(block_id)
# Decrement refcount of the block_id, but do not free the block object
# itself (will be handled by the caller)
self._hashless_allocator.free(block, keep_block_object=True)
def _allocate_block_id(self) -> BlockId:
"""First tries to allocate a block id from the hashless allocator,
and if there are no blocks, then tries to evict an unused cached block.
"""
hashless_block_id = self._maybe_allocate_hashless_block_id()
if hashless_block_id is not None:
return hashless_block_id
evicted_block_id = self._maybe_allocate_evicted_block_id()
if evicted_block_id is not None:
return evicted_block_id
# No block available in hashless allocator, nor in unused cache blocks.
raise BlockAllocator.NoFreeBlocksError()
def _maybe_allocate_hashless_block_id(self) -> Optional[BlockId]:
try:
# Allocate mutable block and extract its block_id
block = self._hashless_allocator.allocate_mutable_block(
prev_block=None)
block_id = block.block_id
self._block_pool.free_block(block)
self._track_block_id(block_id, computed=False)
return block_id
except BlockAllocator.NoFreeBlocksError:
return None
def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]:
if self.evictor.num_blocks == 0:
return None
# Here we get an evicted block, which is only added
# into evictor if its ref counter is 0
# and since its content would be changed, we need
# to remove it from _cached_blocks's tracking list
block_id, content_hash_to_evict = self.evictor.evict()
# Sanity checks
assert content_hash_to_evict in self._cached_blocks
_block_id = self._cached_blocks[content_hash_to_evict]
assert self._refcounter.get(_block_id) == 0
assert _block_id == block_id
@ -175,88 +307,41 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self._cached_blocks.pop(content_hash_to_evict)
self._refcounter.incr(block_id)
self._track_block_id(block_id, computed=False)
# Now this block is pop from evictor and ready to write
# with new content which most probably different with
# original content. So need to tell worker to recompute
# its kvcache
block = self._create_block(
prev_block=prev_block,
token_ids=[],
block_size=self._block_size,
allocator=self,
block_id=block_id,
computed=False,
)
assert block.content_hash is None
return block_id
assert block.block_id not in self._blocks
assert block.block_id is not None
self._blocks[block.block_id] = block
return block
# No block available in hashless allocator, nor in unused cache blocks.
raise BlockAllocator.NoFreeBlocksError()
def _incr_refcount_cached_block(self, block: Block,
block_id: BlockId) -> None:
# now _incr_refcount_cached_block comes from two place
# allocate_immutable/promote_to_immutable_block where hit
# _cached_blocks hash key.
# In both cases, it means that already exists a already
# computed block which shared with block now
block.computed = True
refcount = self._refcounter.incr(block_id)
if refcount == 1:
# if block get referred, then it shall not be in evictor
# and put it into _blocks for tracking
if block_id in self.evictor:
self.evictor.remove(block_id)
self._blocks[block_id] = block
def free(self, block: Block) -> None:
"""Decrement the refcount of the block. If the decremented refcount is
zero, store the block in the freelist.
If the block has a content hash (meaning it is immutable), then we will
keep the block around in case future allocations require it.
def _free_block_id(self, block: Block) -> None:
"""Decrements the refcount of the block. The block may be in two
possible states: (1) immutable/cached or (2) mutable/hashless.
In the first case, the refcount is decremented directly and the block
may be possibly added to the evictor. In other case, hashless
allocator free(..) with keep_block_object=True is called to only free
the block id (since the block object may be reused by the caller)
"""
assert (block.block_id
is not None), "freeing unallocated block is undefined"
block_id = block.block_id
assert block_id is not None, "Freeing unallocated block is undefined"
self._free_block_id_for_block(block.block_id, block)
if block.content_hash is not None:
# Immutable: This type of block is always cached, and we want to
# keep it in the evictor for future reuse
self._decr_refcount_cached_block(block)
else:
# Mutable: This type of block is not cached, so we release it
# directly to the hashless allocator
self._decr_refcount_hashless_block(block)
block.block_id = None
assert block.block_id is None
def _free_block_id_for_block(self, block_id: BlockId,
block: Block) -> None:
assert isinstance(block, PrefixCachingBlock)
def free(self, block: Block, keep_block_object: bool = False) -> None:
"""Release the block (look at free_block_id(..) docs)
"""
# Release the physical block index
self._free_block_id(block)
# if we comes from promote_to_immutable_block, it means that
# block.content_hash is never None.
# However we need to release the same content block, so that
# physical block could get reused.
if block.block_id != block_id or block.content_hash is None:
refcount = self._refcounter.get(block_id)
# We have fork case where block would get more than one ref,
# so we cannot free it from tracking if ref cnt large than 1
assert block.block_id is not None
refcount = self._refcounter.get(block.block_id)
if refcount == 1:
del self._blocks[block.block_id]
return self._hashless_allocator.free(block)
refcount = self._refcounter.decr(block_id)
# If no longer used, add the block to the evictor.
if refcount == 0:
assert block.content_hash in self._cached_blocks
assert block.block_id is not None
del self._blocks[block.block_id]
self.evictor.add(block.block_id, block.content_hash,
block.num_tokens_total, block.last_accessed)
# Release the block object to the pool
if not keep_block_object:
self._block_pool.free_block(block)
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
@ -274,17 +359,20 @@ class PrefixCachingBlockAllocator(BlockAllocator):
forked_blocks: List[Block] = []
prev_block = None
for block in source_blocks:
refcount = self._refcounter.incr(block.block_id)
assert refcount != 1, "can't fork free'd block"
block_id = block.block_id
assert block_id is not None
forked_blocks.append(
self._create_block(
refcount = self._refcounter.incr(block_id)
assert refcount != 1, "can't fork free'd block_id = {}".format(
block_id)
forked_block = self._block_pool.init_block(
prev_block=prev_block,
token_ids=block.token_ids,
block_id=block.block_id,
block_size=self._block_size,
allocator=self,
))
physical_block_id=block_id)
forked_blocks.append(forked_block)
prev_block = forked_blocks[-1]
return forked_blocks
@ -329,7 +417,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
Note that if we already have a cached block with the same content, we
will replace the newly-promoted block's mapping with the existing cached
block.
block id.
Args:
block: The mutable block to be promoted.
@ -338,23 +426,30 @@ class PrefixCachingBlockAllocator(BlockAllocator):
BlockId: Either the original block index, or the block index of
the previously cached block matching the same content.
"""
# Ensure block can be promoted
assert block.content_hash is not None
assert block.block_id is not None
assert self._refcounter.get(block.block_id) > 0
# If the content hash does not have a corresponding cached block,
# set this block as the cached block.
if block.content_hash not in self._cached_blocks:
# No cached content hash => Set this block as cached
# (Note that this block is not computed yet =>
# Will be computed after free())
self._cached_blocks[block.content_hash] = block.block_id
else:
self._free_block_id_for_block(
self._cached_blocks[block.content_hash], block)
self._incr_refcount_cached_block(
block, self._cached_blocks[block.content_hash])
return block.block_id
return self._cached_blocks[block.content_hash]
# Reuse the cached content hash
self._decr_refcount_hashless_block(block)
block.block_id = self._cached_blocks[block.content_hash]
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
# Increment refcount of the cached block and (possibly) restore
# it from the evictor.
# Note that in this case, the block is marked as computed
self._incr_refcount_cached_block(block)
return block.block_id
def cow_block_if_not_appendable(self, block: Block) -> BlockId:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
@ -362,11 +457,22 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block (Block): The block to check for copy-on-write.
Returns:
Optional[BlockId]: The block index of the new block if a copy-on
-write operation was performed, or the original block index if
BlockId: The block index of the new block if a copy-on-write
operation was performed, or the original block index if
no copy-on-write was necessary.
"""
return self._cow_tracker.cow_block_if_not_appendable(block)
src_block_id = block.block_id
assert src_block_id is not None
if self._cow_tracker.is_appendable(block):
return src_block_id
self._free_block_id(block)
trg_block_id = self._allocate_block_id()
self._cow_tracker.record_cow(src_block_id, trg_block_id)
return trg_block_id
def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
"""Returns the copy-on-write source->destination mapping and clears it.
@ -386,8 +492,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
"""
for block_id in block_ids:
if block_id in self._blocks:
self._blocks[block_id].last_accessed = now
if self._block_tracker[block_id].active:
self._block_tracker[block_id].last_accessed = now
elif block_id in self.evictor:
self.evictor.update(block_id, now)
else:
@ -395,25 +501,46 @@ class PrefixCachingBlockAllocator(BlockAllocator):
"Mark block as accessed which is not belonged to GPU")
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
"""Mark blocks as computed, used in prefix caching."""
raise NotImplementedError("Marking as computed is incremental")
for block_id in block_ids:
if block_id in self._blocks:
# only those full block is valid for prefix caching
if self._blocks[block_id].is_full:
self._blocks[block_id].computed = True
elif block_id not in self.evictor:
raise ValueError(f"Mark {block_id=} as computed which "
"is not belonged to GPU")
def _track_block_id(self, block_id: Optional[BlockId],
computed: bool) -> None:
assert block_id is not None
self._block_tracker[block_id].enable()
self._block_tracker[block_id].computed = computed
def _untrack_block_id(self, block_id: Optional[BlockId]) -> None:
assert block_id is not None
self._block_tracker[block_id].disable()
def block_is_computed(self, block_id: int) -> bool:
if block_id in self._blocks:
return self._blocks[block_id].computed
if self._block_tracker[block_id].active:
return self._block_tracker[block_id].computed
else:
return block_id in self.evictor
def get_computed_block_ids(self,
prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool = True) -> List[int]:
prev_prefix_size = len(prev_computed_block_ids)
cur_size = len(block_ids)
if skip_last_block_id:
cur_size -= 1
# Sanity checks
assert cur_size >= 0
assert prev_prefix_size <= cur_size
ret = prev_computed_block_ids
for i in range(prev_prefix_size, cur_size):
block_id = block_ids[i]
if self.block_is_computed(block_id):
ret.append(block_id)
return ret
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
"""Return the block ids that are common for a given sequence group.
Only those blocks that are immutable and already be marked
@ -424,14 +551,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# prompt is cached. This would cause erroneous behavior in model
# runner.
ids_list = [
list(
takewhile(lambda block_id: self.block_is_computed(block_id),
seq[:-1])) for seq in seq_block_ids
]
# It returns a list of int although type annotation says list of string.
return commonprefix([
ids for ids in ids_list # type: ignore
ids for ids in computed_seq_block_ids # type: ignore
if ids != []
])
@ -473,10 +595,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
blocks: List of blocks to be swapped out.
"""
for block in blocks:
self.free(block)
self._free_block_id(block)
def swap_in(self, blocks: List[Block]) -> None:
"""Execute the swap int actions. Change the block id from
"""Execute the swap in actions. Change the block id from
old allocator to current allocator for each block to finish
the block table update.
@ -484,13 +606,22 @@ class PrefixCachingBlockAllocator(BlockAllocator):
blocks: List of blocks to be swapped in.
"""
for block in blocks:
# Here we allocate either immutable or mutable block and then
# extract its block_id. Note that the block object is released
# and the block_id is assigned to "block" to allow reusing the
# existing "block" object
if block.is_full:
alloc = self.allocate_immutable(block.prev_block,
block.token_ids)
tmp_block = self.allocate_immutable_block(
prev_block=block.prev_block, token_ids=block.token_ids)
else:
alloc = self.allocate_mutable(block.prev_block)
alloc.append_token_ids(block.token_ids)
block.block_id = alloc.block_id
tmp_block = self.allocate_mutable_block(
prev_block=block.prev_block)
tmp_block.append_token_ids(block.token_ids)
block_id = tmp_block.block_id
self._block_pool.free_block(tmp_block)
block.block_id = block_id # Assign block_id
class PrefixCachingBlock(Block):
@ -507,7 +638,7 @@ class PrefixCachingBlock(Block):
token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in
the block.
prefix_caching_allocator (BlockAllocator): The prefix
allocator (BlockAllocator): The prefix
caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index
of this block. Defaults to None.
@ -518,31 +649,55 @@ class PrefixCachingBlock(Block):
prev_block: Optional[Block],
token_ids: List[int],
block_size: int,
prefix_caching_allocator: BlockAllocator,
allocator: BlockAllocator,
block_id: Optional[int] = None,
computed: bool = False,
):
assert isinstance(prefix_caching_allocator,
PrefixCachingBlockAllocator), (
assert isinstance(allocator, PrefixCachingBlockAllocator), (
"Currently this class is only tested with "
"PrefixCachingBlockAllocator.")
"PrefixCachingBlockAllocator. Got instead allocator = {}".format(
allocator))
assert_prefix_caching_block_or_none(prev_block)
self._prev_block = prev_block
self._cached_content_hash: Optional[int] = None
self._cached_num_tokens_total: Optional[int] = None
self._prefix_caching_allocator = prefix_caching_allocator
self._cached_num_tokens_total: int = 0
self._allocator = allocator
self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
self._computed = computed
self._block = NaiveBlock(
# On the first time, we create the block object, and next we only
# reinitialize it
if hasattr(self, "_block"):
self._block.__init__( # type: ignore[has-type]
prev_block=prev_block,
token_ids=token_ids,
block_size=block_size,
block_id=block_id,
allocator=prefix_caching_allocator,
_cow_target=self,
)
allocator=self._allocator)
else:
self._block = NaiveBlock(prev_block=prev_block,
token_ids=token_ids,
block_size=block_size,
block_id=block_id,
allocator=self._allocator)
self._update_num_tokens_total()
def _update_num_tokens_total(self):
"""Incrementally computes the number of tokens that there is
till the current block (included)
"""
res = 0
# Add all previous blocks
if self._prev_block is not None:
res += self._prev_block.num_tokens_total
# Add current block
res += len(self.token_ids)
self._cached_num_tokens_total = res
@property
def computed(self) -> bool:
@ -564,22 +719,28 @@ class PrefixCachingBlock(Block):
"""Appends the given token IDs to the block and registers the block as
immutable if the block becomes full.
Internally, the naive block handles CoW.
Args:
token_ids (List[int]): The token IDs to be appended to the block.
"""
assert token_ids
# Ensure this is mutable block (not promoted)
assert self.content_hash is None
assert not self.computed
# naive block handles CoW.
if len(token_ids) == 0:
return
# Ensure there are input tokens
assert token_ids, "Got token_ids = {}".format(token_ids)
# Naive block handles CoW.
self._block.append_token_ids(token_ids)
self._update_num_tokens_total()
# If the content hash is present, then the block can be made immutable.
# Register ourselves with the allocator, potentially replacing the
# physical block index.
if self.content_hash is not None:
self.block_id = (self._prefix_caching_allocator.
promote_to_immutable_block(self))
self.block_id = self._allocator.promote_to_immutable_block(self)
@property
def block_id(self) -> Optional[int]:
@ -599,23 +760,6 @@ class PrefixCachingBlock(Block):
@property
def num_tokens_total(self) -> int:
"""return the total tokens so far.
Here we iterate the block chain till to the first block, while
cache the result in local to prevent repeated computations.
"""
if self._cached_num_tokens_total is not None:
return self._cached_num_tokens_total
_block: Optional[Block] = self
self._cached_num_tokens_total = 0
# TODO: current implement here take O(N^2), we expect future
# we have O(1) here
while _block is not None:
self._cached_num_tokens_total += len(_block.token_ids)
_block = _block.prev_block
return self._cached_num_tokens_total
@property
@ -638,7 +782,6 @@ class PrefixCachingBlock(Block):
For the content-based hash to be defined, the current block must be
full.
"""
# If the hash is already computed, return it.
if self._cached_content_hash is not None:
return self._cached_content_hash
@ -688,7 +831,129 @@ class PrefixCachingBlock(Block):
return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
class ComputedBlocksTracker:
"""Handles caching of per-sequence computed block ids.
When a sequence appears for the first time, it traverses all of the
blocks and detects the prefix of blocks that is computed. On the
subsequent times, it only traverses the new blocks that were added
and updates the already recorded prefix of blocks with the newly
computed blocks.
To avoid redundant traversals, the algorithm also detects when there
is a "gap" in the computed prefix. For example, if we have blocks =
[1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then
we won't try to add more computed blocks to [1,2,3] in this sequence
iteration, and will add more computed blocks only after the sequence is
freed and reused again.
Note that currently, for a given sequence, we also skip the last
block id for caching purposes, to avoid caching of a full sequence
"""
def __init__(self, allocator):
self._allocator = allocator
self._cached_computed_seq_blocks: Dict[int, Tuple[List[int],
bool]] = {}
def add_seq(self, seq_id: int) -> None:
"""Start tracking seq_id
"""
assert seq_id not in self._cached_computed_seq_blocks
self._cached_computed_seq_blocks[seq_id] = ([], False)
def remove_seq(self, seq_id: int) -> None:
"""Stop tracking seq_id
"""
assert seq_id in self._cached_computed_seq_blocks
del self._cached_computed_seq_blocks[seq_id]
def get_cached_computed_blocks_and_update(
self, seq_id: int, block_ids: List[int]) -> List[int]:
""" Look at the class documentation for details
"""
# Ensure seq_id is already tracked
assert seq_id in self._cached_computed_seq_blocks
# Get cached data (may be empty on the first time)
prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[
seq_id]
if has_gap:
# When gap is detected, we do not add more computed blocks at this
# sequence iteration
return prev_computed_block_ids
# We do not consider the last block id for caching purposes.
num_cur_blocks = len(block_ids) - 1
assert num_cur_blocks >= 0
if len(prev_computed_block_ids) >= num_cur_blocks:
# Cache HIT
assert len(prev_computed_block_ids) == num_cur_blocks
return prev_computed_block_ids
# If here, then we may possibly add more computed blocks. As a result,
# traverse the additional blocks after prev_computed_block_ids to
# detect more computed blocks and add them.
# Incremental init for seq_id => Look only at the new blocks
computed_block_ids = self._allocator.get_computed_block_ids( # noqa: E501
prev_computed_block_ids,
block_ids,
skip_last_block_id=
True, # We skip last block id to avoid caching of full seq
)
# Detect if there is a "gap"
has_gap = len(computed_block_ids) < num_cur_blocks
# Record
self._cached_computed_seq_blocks[seq_id] = (computed_block_ids,
has_gap)
return computed_block_ids
class LastAccessBlocksTracker:
"""Manages the last access time of the tracked sequences, in order to allow
an efficient update of allocator's block last access times
"""
def __init__(self, allocator):
self._allocator = allocator
self._seq_last_access: Dict[int, Optional[float]] = {}
def add_seq(self, seq_id: int) -> None:
"""Start tracking seq_id
"""
assert seq_id not in self._seq_last_access
self._seq_last_access[seq_id] = None
def remove_seq(self, seq_id: int) -> None:
"""Stop tracking seq_id
"""
assert seq_id in self._seq_last_access
del self._seq_last_access[seq_id]
def update_last_access(self, seq_id: int, time: float) -> None:
assert seq_id in self._seq_last_access
self._seq_last_access[seq_id] = time
def update_seq_blocks_last_access(self, seq_id: int,
block_ids: List[int]) -> None:
assert seq_id in self._seq_last_access
ts = self._seq_last_access[seq_id]
if ts is None:
# No last access was recorded, no need to update.
return
self._allocator.mark_blocks_as_accessed(block_ids, ts)
def assert_prefix_caching_block_or_none(block: Optional[Block]):
if block is None:
return
assert isinstance(block, PrefixCachingBlock)
assert isinstance(block,
PrefixCachingBlock), "Got block = {}".format(block)

View File

@ -7,6 +7,8 @@ from typing import Tuple
from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.interfaces import Block
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
LastAccessBlocksTracker)
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
@ -100,6 +102,11 @@ class BlockSpaceManagerV2(BlockSpaceManager):
self.block_tables: Dict[SeqId, BlockTable] = {}
self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
self._computed_blocks_tracker = ComputedBlocksTracker(
self.block_allocator)
self._last_access_blocks_tracker = LastAccessBlocksTracker(
self.block_allocator)
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
@ -157,10 +164,18 @@ class BlockSpaceManagerV2(BlockSpaceManager):
block_table: BlockTable = self._allocate_sequence(seq)
self.block_tables[seq.seq_id] = block_table
# Track seq
self._computed_blocks_tracker.add_seq(seq.seq_id)
self._last_access_blocks_tracker.add_seq(seq.seq_id)
# Assign the block table for each sequence.
for seq in waiting_seqs[1:]:
self.block_tables[seq.seq_id] = block_table.fork()
# Track seq
self._computed_blocks_tracker.add_seq(seq.seq_id)
self._last_access_blocks_tracker.add_seq(seq.seq_id)
# Allocate cross-attention block table for encoder sequence
#
# NOTE: Here we assume that all sequences in the group have the same
@ -224,11 +239,23 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return new_cows
def free(self, seq: Sequence) -> None:
if seq.seq_id not in self.block_tables:
seq_id = seq.seq_id
if seq_id not in self.block_tables:
# Already freed or haven't been scheduled yet.
return
self.block_tables[seq.seq_id].free()
del self.block_tables[seq.seq_id]
# Update seq block ids with the latest access time
self._last_access_blocks_tracker.update_seq_blocks_last_access(
seq_id, self.block_tables[seq.seq_id].physical_block_ids)
# Untrack seq
self._last_access_blocks_tracker.remove_seq(seq_id)
self._computed_blocks_tracker.remove_seq(seq_id)
# Free table/blocks
self.block_tables[seq_id].free()
del self.block_tables[seq_id]
def free_cross(self, seq_group: SequenceGroup) -> None:
request_id = seq_group.request_id
@ -239,9 +266,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
del self.cross_block_tables[request_id]
def get_block_table(self, seq: Sequence) -> List[int]:
assert seq.seq_id in self.block_tables
block_ids = self.block_tables[seq.seq_id].physical_block_ids
assert all(b is not None for b in block_ids)
return block_ids # type: ignore
def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
@ -252,20 +277,14 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return block_ids # type: ignore
def access_all_blocks_in_seq(self, seq: Sequence, now: float):
# Update the last accessed time of all the blocks accessed
# in this step.
# And the accessed time is only useful for prefix caching now,
# as it support internal evictor policy for which cached
# block could be refilled, to keep cached content could be reused
# at max extend.
if self.enable_caching:
block_table = self.block_tables[seq.seq_id]
block_ids: List[Optional[int]] = []
for block_id in block_table.physical_block_ids:
block_ids.append(block_id)
self.block_allocator.mark_blocks_as_accessed(
block_ids, # type: ignore
now)
# Record the latest access time for the sequence. The actual update
# of the block ids is deferred to the sequence free(..) call, since
# only during freeing of block ids, the blocks are actually added to
# the evictor (which is when the most updated time is required)
# (This avoids expensive calls to mark_blocks_as_accessed(..))
self._last_access_blocks_tracker.update_last_access(
seq.seq_id, now)
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
# The only need for mark block as computed is for prefix caching,
@ -285,17 +304,26 @@ class BlockSpaceManagerV2(BlockSpaceManager):
This method determines which blocks can be safely skipped for all
sequences in the sequence group.
"""
seq_block_ids = [
self.block_tables[seq.seq_id].physical_block_ids for seq in seqs
]
computed_seq_block_ids = []
for seq in seqs:
computed_seq_block_ids.append(
self._computed_blocks_tracker.
get_cached_computed_blocks_and_update(
seq.seq_id,
self.block_tables[seq.seq_id].physical_block_ids))
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return self.block_allocator.get_common_computed_block_ids(
seq_block_ids) # type: ignore
computed_seq_block_ids) # type: ignore
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.fork()
# Track child seq
self._computed_blocks_tracker.add_seq(child_seq.seq_id)
self._last_access_blocks_tracker.add_seq(child_seq.seq_id)
def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> AllocStatus:
"""Returns the AllocStatus for the given sequence_group
@ -323,19 +351,31 @@ class BlockSpaceManagerV2(BlockSpaceManager):
List[Tuple[int, int]]: The mapping of swapping block from CPU
to GPU.
"""
blocks = self._get_blocks_for_swap(seq_group, SequenceStatus.SWAPPED)
current_swap_mapping = self.block_allocator.swap(
blocks=blocks, source_device=Device.CPU, dest_device=Device.GPU)
physical_block_id_mapping = []
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
blocks = self.block_tables[seq.seq_id].blocks
if len(blocks) == 0:
continue
block_number_mapping = {
self.block_allocator.get_physical_block_id(Device.CPU,
cpu_block_id):
self.block_allocator.get_physical_block_id(Device.GPU,
gpu_block_id)
for cpu_block_id, gpu_block_id in current_swap_mapping.items()
seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
src_device=Device.CPU,
dst_device=Device.GPU)
# Refresh the block ids of the table (post-swap)
self.block_tables[seq.seq_id].update(blocks)
seq_physical_block_id_mapping = {
self.block_allocator.get_physical_block_id(
Device.CPU, cpu_block_id):
self.block_allocator.get_physical_block_id(
Device.GPU, gpu_block_id)
for cpu_block_id, gpu_block_id in seq_swap_mapping.items()
}
# convert to list of tuples once here
return list(block_number_mapping.items())
physical_block_id_mapping.extend(
list(seq_physical_block_id_mapping.items()))
return physical_block_id_mapping
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
"""Returns whether we can swap out the given sequence_group
@ -355,7 +395,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return True
return False
def swap_out(self, sequence_group: SequenceGroup) -> List[Tuple[int, int]]:
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
"""Returns the block id mapping (from GPU to CPU) generated by
swapping out the given sequence_group with num_lookahead_slots.
@ -366,19 +406,31 @@ class BlockSpaceManagerV2(BlockSpaceManager):
List[Tuple[int, int]]: The mapping of swapping block from
GPU to CPU.
"""
blocks = self._get_blocks_for_swap(sequence_group,
SequenceStatus.RUNNING)
current_swap_mapping = self.block_allocator.swap(
blocks=blocks, source_device=Device.GPU, dest_device=Device.CPU)
block_number_mapping = {
self.block_allocator.get_physical_block_id(Device.GPU,
gpu_block_id):
self.block_allocator.get_physical_block_id(Device.CPU,
cpu_block_id)
for gpu_block_id, cpu_block_id in current_swap_mapping.items()
physical_block_id_mapping = []
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
blocks = self.block_tables[seq.seq_id].blocks
if len(blocks) == 0:
continue
seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
src_device=Device.GPU,
dst_device=Device.CPU)
# Refresh the block ids of the table (post-swap)
self.block_tables[seq.seq_id].update(blocks)
seq_physical_block_id_mapping = {
self.block_allocator.get_physical_block_id(
Device.GPU, gpu_block_id):
self.block_allocator.get_physical_block_id(
Device.CPU, cpu_block_id)
for gpu_block_id, cpu_block_id in seq_swap_mapping.items()
}
# convert to list of tuples once here
return list(block_number_mapping.items())
physical_block_id_mapping.extend(
list(seq_physical_block_id_mapping.items()))
return physical_block_id_mapping
def get_num_free_gpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.GPU)

View File

@ -177,7 +177,8 @@ class LLMEngine:
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s)",
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"enable_prefix_caching=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
@ -204,6 +205,8 @@ class LLMEngine:
observability_config,
model_config.seed,
model_config.served_model_name,
scheduler_config.use_v2_block_manager,
cache_config.enable_prefix_caching,
)
# TODO(woosuk): Print more configs in debug mode.

View File

@ -345,7 +345,7 @@ class OpenAIServingCompletion(OpenAIServing):
out_logprobs = prompt_logprobs
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids
token_ids = prompt_token_ids + list(output.token_ids)
out_logprobs = (prompt_logprobs + output.logprobs
if request.logprobs is not None else None)
output_text = prompt_text + output.text

View File

@ -427,8 +427,8 @@ class SamplingTensors:
if seq_group.do_sample:
for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)
prompt_tokens.append(list(seq_data.prompt_token_ids))
output_tokens.append(list(seq_data.output_token_ids))
sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, min_ps, presence_penalties,

View File

@ -1,6 +1,6 @@
import time
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
from vllm.lora.request import LoRARequest
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@ -28,7 +28,7 @@ class CompletionOutput:
index: int
text: str
token_ids: List[int]
token_ids: Tuple[int, ...]
cumulative_logprob: float
logprobs: Optional[SampleLogprobs]
finish_reason: Optional[str] = None

View File

@ -116,41 +116,66 @@ class SequenceData:
prompt_token_ids: List[int],
output_token_ids: Optional[List[int]] = None,
) -> None:
if output_token_ids is None:
output_token_ids = []
self._prompt_token_ids: List[int] = list(prompt_token_ids)
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
self._output_token_ids: List[int] = (
list(output_token_ids) if output_token_ids is not None else [])
self.prompt_token_ids = prompt_token_ids
self._prompt_token_ids_tuple = tuple(prompt_token_ids)
self.output_token_ids = output_token_ids
self.cumulative_logprob = 0.0
# The number of tokens that are computed (that run against the model).
self._num_computed_tokens = 0
self._stage: SequenceStage = SequenceStage.PREFILL
self._update_cached_all_tokens()
def _update_cached_all_tokens(self):
self._cached_all_token_ids: List[int] = (self._prompt_token_ids +
self._output_token_ids)
@property
def prompt_token_ids(self) -> Tuple[int, ...]:
return self._prompt_token_ids_tuple
@prompt_token_ids.setter
def prompt_token_ids(self, new_prompt_token_ids) -> None:
self._prompt_token_ids = list(new_prompt_token_ids)
self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
self._update_cached_all_tokens()
@property
def output_token_ids(self) -> Tuple[int, ...]:
return tuple(self._output_token_ids)
@output_token_ids.setter
def output_token_ids(self, new_output_token_ids) -> None:
self._output_token_ids = list(new_output_token_ids)
self._update_cached_all_tokens()
def append_token_id(self, token_id: int, logprob: float) -> None:
self.output_token_ids.append(token_id)
self._output_token_ids.append(token_id)
self._cached_all_token_ids.append(token_id)
self.cumulative_logprob += logprob
def get_len(self) -> int:
return len(self.output_token_ids) + len(self.prompt_token_ids)
return len(self._output_token_ids) + len(self._prompt_token_ids)
def get_prompt_len(self) -> int:
return len(self.prompt_token_ids)
return len(self._prompt_token_ids)
def get_output_len(self) -> int:
return len(self.output_token_ids)
return len(self._output_token_ids)
def get_token_ids(self) -> List[int]:
return self.prompt_token_ids + self.output_token_ids
return self._cached_all_token_ids
def get_prefix_token_ids(
self, num_tokens: int
) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
"""Get prefix tokens, and make the return value hashable"""
prompt_length = len(self.prompt_token_ids)
prompt_length = self.get_prompt_len()
if num_tokens > prompt_length:
return (self._prompt_token_ids_tuple,
tuple(self.output_token_ids[:num_tokens - prompt_length]))
tuple(self._output_token_ids[:num_tokens - prompt_length]))
else:
return (self._prompt_token_ids_tuple[:num_tokens], None)
@ -183,14 +208,14 @@ class SequenceData:
return self.get_len() - self.get_num_computed_tokens()
def get_last_token_id(self) -> int:
if not self.output_token_ids:
return self.prompt_token_ids[-1]
return self.output_token_ids[-1]
if not self._output_token_ids:
return self._prompt_token_ids[-1]
return self._output_token_ids[-1]
def get_prompt_token_ids(self) -> List[int]:
def get_prompt_token_ids(self) -> Tuple[int, ...]:
return self.prompt_token_ids
def get_output_token_ids(self) -> List[int]:
def get_output_token_ids(self) -> Tuple[int, ...]:
return self.output_token_ids
@property
@ -199,8 +224,8 @@ class SequenceData:
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt_token_ids={self.prompt_token_ids}, "
f"output_token_ids={self.output_token_ids}, "
f"prompt_token_ids={self._prompt_token_ids}, "
f"output_token_ids={self._output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob})")
@ -306,14 +331,14 @@ class Sequence:
def get_token_ids(self) -> List[int]:
return self.data.get_token_ids()
def get_prompt_token_ids(self) -> List[int]:
def get_prompt_token_ids(self) -> Tuple[int, ...]:
return self.data.get_prompt_token_ids()
def get_last_token_id(self) -> int:
return self.data.get_last_token_id()
def get_output_token_ids(self) -> List[int]:
return self.data.output_token_ids
def get_output_token_ids(self) -> Tuple[int, ...]:
return self.data.get_output_token_ids()
def get_cumulative_logprob(self) -> float:
return self.data.cumulative_logprob