# vllm/cacheflow/master/block_manager.py

from typing import Dict, List, Optional, Set, Tuple

from cacheflow.block import PhysicalTokenBlock
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceStatus
from cacheflow.utils import Device


class BlockAllocator:
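    """Allocator for physical token blocks on a single device (GPU or CPU).

    The allocator maintains a free list of blocks. Each block carries a
    reference count so that it can be shared by multiple sequences, e.g.
    when the sequences in a group share the same prompt.
    """
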
    def __init__(
        self,
        device: Device,
        block_size: int,
        num_blocks: int,
    ) -> None:
        if block_size not in [8, 16, 32]:
            raise ValueError(f'Unsupported block size: {block_size}. '
                             'The block size must be one of {8, 16, 32}.')
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

        # Initialize the free blocks.
        # TODO(woosuk): Make this a priority queue.
        self.free_blocks = [
            PhysicalTokenBlock(device=device, block_number=i,
                               block_size=block_size)
            for i in range(num_blocks)
        ]

    def allocate(self) -> PhysicalTokenBlock:
        if not self.free_blocks:
            raise ValueError('Out of memory! '
                             'No more free blocks are available.')
        block = self.free_blocks.pop()
        block.ref_count = 1
        return block

    def free(self, block: PhysicalTokenBlock) -> None:
        if block.ref_count == 0:
            raise ValueError(
                f'Double free! The block {block} is already freed.')
        block.ref_count -= 1
        if block.ref_count == 0:
            self.free_blocks.append(block)

    def get_num_free_blocks(self) -> int:
        return len(self.free_blocks)


# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]


class BlockSpaceManager:
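    """Manages the mapping from logical token blocks to physical blocks.

    One block table is kept per sequence. The manager allocates GPU blocks
    for prompts and generated tokens, performs copy-on-write when a shared
    block is appended to, and swaps whole sequence groups between GPU and
    CPU memory.
    """
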
    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        watermark: float = 0.01,
    ) -> None:
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks
        self.watermark = watermark
        assert watermark >= 0.0
        self.watermark_blocks = int(watermark * num_gpu_blocks)

        self.gpu_allocator = BlockAllocator(Device.GPU, block_size,
                                            num_gpu_blocks)
        self.cpu_allocator = BlockAllocator(Device.CPU, block_size,
                                            num_cpu_blocks)
        # Mapping: seq_id -> BlockTable.
        self.block_tables: Dict[int, BlockTable] = {}

    def can_allocate(self, seq_group: SequenceGroup) -> bool:
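        """Check whether the prompt fits in GPU memory above the watermark."""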
        # FIXME(woosuk): Here we assume that all sequences in the group share
        # the same prompt. This may not be true for preempted sequences.
        seq = seq_group.seqs[0]
        num_required_blocks = len(seq.logical_token_blocks)
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        # Use watermark to avoid frequent cache eviction.
        return (num_free_gpu_blocks - num_required_blocks >=
                self.watermark_blocks)

    def allocate(self, seq_group: SequenceGroup) -> None:
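        """Allocate GPU blocks for the prompt, shared across the group."""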
        # NOTE: Here we assume that all sequences in the group have the same
        # prompt.
        seq = seq_group.seqs[0]

        # Allocate new physical token blocks that will store the prompt tokens.
        block_table: BlockTable = []
        for _ in range(len(seq.logical_token_blocks)):
            block = self.gpu_allocator.allocate()
            # Set the reference counts of the token blocks.
            block.ref_count = seq_group.num_seqs()
            block_table.append(block)

        # Assign the block table for each sequence.
        for seq in seq_group.seqs:
            self.block_tables[seq.seq_id] = block_table.copy()

    def can_append(self, seq_group: SequenceGroup) -> bool:
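        """Check whether every running sequence can get one more block."""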
        # Simple heuristic: If there is at least one free block
        # for each sequence, we can append.
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
        return num_seqs <= num_free_gpu_blocks

    def append(self, seq: Sequence) -> Optional[Tuple[int, int]]:
        """Allocate a physical slot for the new token."""
        logical_blocks = seq.logical_token_blocks
        block_table = self.block_tables[seq.seq_id]

        if len(block_table) < len(logical_blocks):
            # The sequence has a new logical block.
            # Allocate a new physical block.
            block = self.gpu_allocator.allocate()
            block_table.append(block)
            return None

        # We want to append the token to the last physical block.
        last_block = block_table[-1]
        assert last_block.device == Device.GPU
        if last_block.ref_count == 1:
            # Not shared with other sequences. Appendable.
            return None
        else:
            # The last block is shared with other sequences.
            # Copy on Write: Allocate a new block and copy the tokens.
            new_block = self.gpu_allocator.allocate()
            block_table[-1] = new_block
            self.gpu_allocator.free(last_block)
            return last_block.block_number, new_block.block_number

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
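        """Let the child share the parent's blocks by bumping ref counts."""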
        # NOTE: fork does not allocate a new physical block.
        # Thus, it is always safe from OOM.
        src_block_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = src_block_table.copy()
        for block in src_block_table:
            block.ref_count += 1

    def _get_physical_blocks(
        self,
        seq_group: SequenceGroup,
    ) -> List[PhysicalTokenBlock]:
        # NOTE: Here, we assume that the physical blocks are only shared by
        # the sequences in the same group.
        blocks: Set[PhysicalTokenBlock] = set()
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.FINISHED:
                continue
            block_table = self.block_tables[seq.seq_id]
            for block in block_table:
                blocks.add(block)
        return list(blocks)

    def can_swap_in(self, seq_group: SequenceGroup) -> bool:
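        """Check whether the group's blocks fit back into free GPU memory."""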
        blocks = self._get_physical_blocks(seq_group)
        num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
        num_free_blocks = self.gpu_allocator.get_num_free_blocks()
        # NOTE: Conservatively, we assume that every sequence will allocate
        # at least one free block right after the swap-in.
        # NOTE: This should match the logic in can_append().
        num_required_blocks = len(blocks) + num_swapped_seqs
        return num_free_blocks - num_required_blocks >= self.watermark_blocks

    def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
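        """Move the group's blocks from CPU to GPU; return the block mapping."""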
        # CPU block -> GPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.FINISHED:
                continue
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]

            for cpu_block in block_table:
                if cpu_block in mapping:
                    gpu_block = mapping[cpu_block]
                    gpu_block.ref_count += 1
                else:
                    gpu_block = self.gpu_allocator.allocate()
                    mapping[cpu_block] = gpu_block
                new_block_table.append(gpu_block)
                # Free the CPU block swapped in to GPU.
                self.cpu_allocator.free(cpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            cpu_block.block_number: gpu_block.block_number
            for cpu_block, gpu_block in mapping.items()
        }
        return block_number_mapping

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        blocks = self._get_physical_blocks(seq_group)
        return len(blocks) <= self.cpu_allocator.get_num_free_blocks()

    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
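        """Move the group's blocks from GPU to CPU; return the block mapping."""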
        # GPU block -> CPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.FINISHED:
                continue
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]

            for gpu_block in block_table:
                if gpu_block in mapping:
                    cpu_block = mapping[gpu_block]
                    cpu_block.ref_count += 1
                else:
                    cpu_block = self.cpu_allocator.allocate()
                    mapping[gpu_block] = cpu_block
                new_block_table.append(cpu_block)
                # Free the GPU block swapped out to CPU.
                self.gpu_allocator.free(gpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            gpu_block.block_number: cpu_block.block_number
            for gpu_block, cpu_block in mapping.items()
        }
        return block_number_mapping

    def _free_block_table(self, block_table: BlockTable) -> None:
        for block in block_table:
            if block.device == Device.GPU:
                self.gpu_allocator.free(block)
            else:
                self.cpu_allocator.free(block)

    def free(self, seq: Sequence) -> None:
        block_table = self.block_tables[seq.seq_id]
        self._free_block_table(block_table)
        del self.block_tables[seq.seq_id]

    def reset(self) -> None:
        for block_table in self.block_tables.values():
            self._free_block_table(block_table)
        self.block_tables.clear()

    def get_block_table(self, seq: Sequence) -> List[int]:
        block_table = self.block_tables[seq.seq_id]
        return [block.block_number for block in block_table]

    def get_num_free_gpu_blocks(self) -> int:
        return self.gpu_allocator.get_num_free_blocks()

    def get_num_free_cpu_blocks(self) -> int:
        return self.cpu_allocator.get_num_free_blocks()
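

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). A minimal illustration of
# how a scheduler might drive the BlockSpaceManager; the `seq_group` object
# is assumed to come from cacheflow.sequence and is hypothetical here.
#
#     manager = BlockSpaceManager(block_size=16,
#                                 num_gpu_blocks=1024,
#                                 num_cpu_blocks=2048)
#
#     # Admission: reserve GPU blocks for the prompt.
#     if manager.can_allocate(seq_group):
#         manager.allocate(seq_group)
#
#     # One decoding step: reserve a slot for each sequence's new token.
#     for seq in seq_group.seqs:
#         ret = manager.append(seq)
#         if ret is not None:
#             src, dst = ret  # schedule a copy-on-write from src to dst
#
#     # Under memory pressure: move the whole group to CPU memory.
#     if manager.can_swap_out(seq_group):
#         gpu_to_cpu = manager.swap_out(seq_group)
# ---------------------------------------------------------------------------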