[Bugfix] Block manager v2 with preemption and lookahead slots (#8824)

sroy745 2024-09-28 18:17:45 -07:00 committed by GitHub
parent d1537039ce
commit 5bf8789b2a
9 changed files with 133 additions and 116 deletions

View File

@@ -23,8 +23,10 @@ MODELS = [
 @pytest.fixture(scope="module", autouse=True)
 def check_settings():
     assert ENABLE_ARTIFICIAL_PREEMPT is True, (
-        "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. "
-        "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest "
+        "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1, "
+        "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1. "
+        "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 "
+        "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1 pytest "
         "tests/basic_correctness/test_preemption.py`")
@@ -199,6 +201,7 @@ def test_swap(
 @pytest.mark.parametrize("dtype", ["float"])
 @pytest.mark.parametrize("max_tokens", [96])
 @pytest.mark.parametrize("beam_width", [4])
+@pytest.mark.parametrize("use_v2_block_manager", [True, False])
 def test_swap_infeasible(
     vllm_runner,
     example_prompts,
@@ -207,6 +210,7 @@ def test_swap_infeasible(
     max_tokens: int,
     beam_width: int,
     worker_use_ray: bool,
+    use_v2_block_manager: bool,
 ) -> None:
     """Verify infeasible swap request will be ignored."""
     BLOCK_SIZE = 16
@@ -223,6 +227,7 @@ def test_swap_infeasible(
             num_gpu_blocks_override=prefill_blocks + decode_blocks,
             max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
             worker_use_ray=worker_use_ray,
+            use_v2_block_manager=use_v2_block_manager,
     ) as vllm_model:
         sampling_params = SamplingParams(n=beam_width,
                                          use_beam_search=True,

View File

@@ -373,6 +373,52 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots,
             seq_group, num_lookahead_slots) == AllocStatus.NEVER


+@pytest.mark.parametrize("num_lookahead_slots", [0, 2, 10])
+@pytest.mark.parametrize("enable_caching", [False, True])
+def test_swap_in_infeasible(num_lookahead_slots, enable_caching):
+    """Verifies that swapping in fails if there are not enough free blocks
+    to account for unseen tokens and lookahead slots.
+    """
+    block_size = 8
+    num_cpu_blocks = 1
+    num_gpu_blocks = 1
+    block_manager = BlockSpaceManagerV2(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0,
+                                        enable_caching=enable_caching)
+    prompt_length = block_size - 3
+    assert prompt_length > 0
+    prompt, seq_group = create_dummy_prompt("1", prompt_length=prompt_length)
+    prompt.status = SequenceStatus.WAITING
+    block_manager.allocate(seq_group)
+    # Emulate a forward pass by appending a single token.
+    # The block manager then knows how many unprocessed
+    # tokens will be written in the next forward pass.
+    token_id = 0
+    prompt.status = SequenceStatus.RUNNING
+    prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
+
+    # Swap seq group from GPU -> CPU.
+    assert block_manager.can_swap_out(seq_group)
+    block_manager.swap_out(seq_group)
+    prompt.status = SequenceStatus.SWAPPED
+
+    # Swap seq group from CPU -> GPU.
+    # The number of unseen tokens is 1. If the number of existing
+    # tokens plus the unseen ones and the number of lookahead slots exceeds
+    # the total number of available GPU blocks, then the swap
+    # should fail.
+    num_unseen_tokens = 1
+    if (num_lookahead_slots + num_unseen_tokens +
+            prompt_length) <= (block_size * num_gpu_blocks):
+        assert block_manager.can_swap_in(seq_group,
+                                         num_lookahead_slots) == AllocStatus.OK
+    else:
+        assert block_manager.can_swap_in(
+            seq_group, num_lookahead_slots) == AllocStatus.NEVER
+
+
 # TODO(cade/kaiyang): add comprehensive tests for swapping at allocator level.
@@ -400,7 +446,6 @@ def test_sliding_window(block_size, prompt_len, num_slots_to_append,
         if max_n is None:
             max_n = min_n
         used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks()
-        #print("check", min_n, used, max_n)
         assert min_n <= used
         assert used <= max_n
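
The feasibility condition the new test exercises reduces to comparing the slots a swapped-in sequence needs against the slots the GPU pool can hold. Below is a minimal standalone sketch of that arithmetic in plain Python (mirroring the test's `prompt_length`, `num_unseen_tokens`, and `num_lookahead_slots`; it is not the vLLM API):

# Standalone sketch of the swap-in feasibility arithmetic exercised above.
# Assumption: a single GPU pool of `num_gpu_blocks` blocks with `block_size`
# slots each must hold the existing tokens, the unseen tokens, and the
# lookahead slots.
def swap_in_is_feasible(prompt_length: int,
                        num_unseen_tokens: int,
                        num_lookahead_slots: int,
                        block_size: int,
                        num_gpu_blocks: int) -> bool:
    slots_needed = prompt_length + num_unseen_tokens + num_lookahead_slots
    return slots_needed <= block_size * num_gpu_blocks


# With block_size=8 and a single GPU block, as in the test
# (prompt_length = 8 - 3 = 5, one unseen token):
assert swap_in_is_feasible(5, 1, 2, 8, 1)       # 8 slots fit in 8
assert not swap_in_is_feasible(5, 1, 10, 8, 1)  # 16 slots do not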

View File

@@ -104,9 +104,9 @@ class TestNaiveBlockAllocator:
     @staticmethod
     @pytest.mark.parametrize("num_blocks", [4])
     @pytest.mark.parametrize("block_size", [8])
-    def test_naive_block_get_num_blocks_touched(num_blocks, block_size):
+    def test_naive_block_get_num_full_blocks_touched(num_blocks, block_size):
         """ Verify the allocator can correctly return the number of
-        blocks touched, with different lookahead slots.
+        full blocks touched.
         """
         allocator_src = NaiveBlockAllocator(create_block=NaiveBlock,
                                             num_blocks=num_blocks,
@@ -124,7 +124,7 @@ class TestNaiveBlockAllocator:
         src_blocks = [allocate_block() for _ in range(num_blocks - 1)]

         # All blocks are cached
-        assert allocator_dst.get_num_blocks_touched(
+        assert allocator_dst.get_num_full_blocks_touched(
             src_blocks) == num_blocks - 1

         # Insert one non-full block in the src
@@ -136,9 +136,10 @@ class TestNaiveBlockAllocator:
         src_blocks.append(allocate_non_full_block())
         src_blocks[-1].append_token_ids([0])

-        assert allocator_dst.get_num_blocks_touched(
-            src_blocks, num_lookahead_slots=1) == num_blocks
-        assert allocator_dst.get_num_blocks_touched(
-            src_blocks, num_lookahead_slots=block_size - 1) == num_blocks
-        assert allocator_dst.get_num_blocks_touched(
-            src_blocks, num_lookahead_slots=block_size) == (num_blocks + 1)
+        assert allocator_dst.get_num_full_blocks_touched(
+            src_blocks) == num_blocks - 1
+
+        # Fill up the last source block and then invoke
+        # get_num_full_blocks_touched.
+        src_blocks[-1].append_token_ids([0] * (block_size - 1))
+        assert allocator_dst.get_num_full_blocks_touched(
+            src_blocks) == num_blocks

View File

@@ -318,11 +318,10 @@ class TestPrefixCachingBlockAllocator:
     @staticmethod
     @pytest.mark.parametrize("num_blocks", [4])
     @pytest.mark.parametrize("block_size", [8])
-    def test_prefix_caching_block_get_num_blocks_touched(
+    def test_prefix_caching_block_get_num_full_blocks_touched(
             num_blocks, block_size):
         """ Verify the allocator can correctly return the number of
-        blocks touched, when there are cached prefixes and different
-        lookahead slots.
+        blocks touched, when there are cached prefixes.
         """
         allocator_src = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                     block_size=block_size)
@@ -346,28 +345,30 @@ class TestPrefixCachingBlockAllocator:
             token_ids=token_ids,
             allocator=allocator_src,
         )

         # All blocks are cached
-        assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in) == 0
+        assert allocator_dst.get_num_full_blocks_touched(
+            blocks_to_swap_in) == 0

         # Free the first block in the dst
         allocator_dst.free(cached_blocks[0])

         # Now the first block becomes dangling, the swapped blocks need
         # to reclaim the first block in the dst
-        assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in) == 1
+        assert allocator_dst.get_num_full_blocks_touched(
+            blocks_to_swap_in) == 1

         # Insert one non-full block in the src
         non_full_block = allocator_src.allocate_mutable_block(
             blocks_to_swap_in[-1])
         non_full_block.append_token_ids([0])
         blocks_to_swap_in.append(non_full_block)

-        assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in,
-                                                    num_lookahead_slots=1) == 2
-        assert allocator_dst.get_num_blocks_touched(
-            blocks_to_swap_in, num_lookahead_slots=block_size - 1) == 2
-        assert allocator_dst.get_num_blocks_touched(
-            blocks_to_swap_in, num_lookahead_slots=block_size) == 3
+        assert allocator_dst.get_num_full_blocks_touched(
+            blocks_to_swap_in) == 1
+
+        # Fill up the last mutable block and invoke
+        # get_num_full_blocks_touched.
+        # Note: The last block is not cached so it will be touched.
+        non_full_block.append_token_ids([0] * (block_size - 1))
+        assert allocator_dst.get_num_full_blocks_touched(
+            blocks_to_swap_in) == 2

     @staticmethod
     @pytest.mark.parametrize("num_blocks", [1024])

View File

@@ -259,25 +259,22 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
             current_swap_mapping[src_block_id] = dst_block_id
         return current_swap_mapping

-    def get_num_blocks_touched(self,
-                               blocks: List[Block],
-                               device: Device,
-                               num_lookahead_slots: int = 0) -> int:
-        """Returns the number of blocks that will be touched by
+    def get_num_full_blocks_touched(self, blocks: List[Block],
+                                    device: Device) -> int:
+        """Returns the number of full blocks that will be touched by
         swapping in/out the given blocks on to the 'device'.

         Args:
             blocks: List of blocks to be swapped.
             device (Device): Device to swap the 'blocks' on.
-            num_lookahead_slots (int): Number of lookahead slots used in
-                speculative decoding, default to 0.

         Returns:
-            int: the number of blocks that will be touched by
+            int: the number of full blocks that will be touched by
                 swapping in/out the given blocks on to the 'device'.
+                Non full blocks are ignored when deciding the number
+                of blocks to touch.
         """
-        return self._allocators[device].get_num_blocks_touched(
-            blocks, num_lookahead_slots)
+        return self._allocators[device].get_num_full_blocks_touched(blocks)

     def clear_copy_on_writes(self) -> List[Tuple[int, int]]:
         """Clears the copy-on-write (CoW) state and returns the mapping of

View File

@@ -181,9 +181,7 @@ class BlockAllocator(ABC):
         pass

     @abstractmethod
-    def get_num_blocks_touched(self,
-                               blocks: List[Block],
-                               num_lookahead_slots: int = 0) -> int:
+    def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
         pass

     @abstractmethod
@@ -260,10 +258,8 @@ class DeviceAwareBlockAllocator(ABC):
         pass

     @abstractmethod
-    def get_num_blocks_touched(self,
-                               blocks: List[Block],
-                               device: Device,
-                               num_lookahead_slots: int = 0) -> int:
+    def get_num_full_blocks_touched(self, blocks: List[Block],
+                                    device: Device) -> int:
         pass

     @abstractmethod

View File

@@ -4,7 +4,6 @@ from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple
 from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
                                     get_all_blocks_recursively)
 from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
-from vllm.utils import cdiv

 Refcount = int
@@ -282,40 +281,26 @@ class NaiveBlockAllocator(BlockAllocator):
     def promote_to_immutable_block(self, block: Block) -> BlockId:
         raise NotImplementedError("There is no promotion for naive blocks")

-    def get_num_blocks_touched(self,
-                               blocks: List[Block],
-                               num_lookahead_slots: int = 0) -> int:
-        """Determine the number of blocks that will be touched by
-        swapping in/out the given blocks from certain sequence
-        group with the provided num_lookahead_slots.
+    def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
+        """Returns the number of full blocks that will be touched by
+        swapping in/out.

         Args:
-            blocks (List[Block]): The potential blocks to swap.
-            num_lookahead_slots (int): number of lookahead slots (0 for swap
-                out).
+            blocks: List of blocks to be swapped.
         Returns:
-            int: the number of blocks that will be touched by
-                swapping in/out the given blocks and num_lookahead_slots.
+            int: the number of full blocks that will be touched by
+                swapping in/out the given blocks. Non full blocks are ignored
+                when deciding the number of blocks to touch.
         """
         # NOTE: for naive block, we use set to eliminate common blocks among
         # seqs, also we compare the empty slots in the mutable blocks with
         # lookahead slots to get the number of unique new block that are
         # needed.
         old_block_set = set()
-        new_block_count = 0
-        # TODO(cade): make sure the logic is correct and clean it up.
         for block in blocks:
-            if not block.is_full and num_lookahead_slots != 0:
-                new_block_count += 1
-                if num_lookahead_slots > block.num_empty_slots:
-                    new_block_count += cdiv(
-                        num_lookahead_slots - block.num_empty_slots,
-                        self._block_size)
-            else:
-                old_block_set.add(block.block_id)
-        num_touched_blocks = new_block_count + len(old_block_set)
-        return num_touched_blocks
+            if block.is_full:
+                old_block_set.add(block)
+        return len(old_block_set)

     def swap_out(self, blocks: List[Block]) -> None:
         for block in blocks:
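
The rewritten naive-allocator method now counts only full blocks and deduplicates blocks shared across sequences. A self-contained sketch of that rule, using a hypothetical `FakeBlock` stand-in rather than the real `NaiveBlock`:

from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class FakeBlock:
    """Stand-in for a block: only the fields the counting rule needs."""
    block_id: int
    is_full: bool


def get_num_full_blocks_touched(blocks: List[FakeBlock]) -> int:
    # Mirror of the naive allocator rule: non-full blocks are ignored,
    # and a block shared by several sequences is counted once.
    return len({block for block in blocks if block.is_full})


full_a = FakeBlock(block_id=0, is_full=True)
full_b = FakeBlock(block_id=1, is_full=True)
partial = FakeBlock(block_id=2, is_full=False)

# full_a appears twice (e.g. a shared prefix) but is counted once;
# the partial block is not counted at all.
assert get_num_full_blocks_touched([full_a, full_a, full_b, partial]) == 2

The set-based deduplication is what keeps a prefix shared by several sequences from being counted once per sequence.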

View File

@@ -8,7 +8,6 @@ from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
 from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
                                          NaiveBlockAllocator)
 from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor
-from vllm.utils import cdiv

 PrefixHash = int
@@ -576,37 +575,27 @@ class PrefixCachingBlockAllocator(BlockAllocator):
             if ids
         ])

-    def get_num_blocks_touched(self,
-                               blocks: List[Block],
-                               num_lookahead_slots: int = 0) -> int:
-        """Determine the number of blocks that will be touched by
-        swapping in/out the given blocks from certain sequence
-        group with the provided num_lookahead_slots.
+    def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
+        """Returns the number of full blocks that will be touched by
+        swapping in/out.

         Args:
-            blocks (List[Block]): The potential blocks to swap.
-            num_lookahead_slots (int): number of lookahead slots (0 for
-                swap out).
+            blocks: List of blocks to be swapped.
         Returns:
-            int: the number of blocks that will be touched by
-                swapping in/out the given blocks and num_lookahead_slots.
+            int: the number of full blocks that will be touched by
+                swapping in/out the given blocks. Non full blocks are ignored
+                when deciding the number of blocks to touch.
         """
-        num_touched_blocks = 0
+        num_touched_blocks: int = 0
         for block in blocks:
-            if not block.is_full:
-                num_touched_blocks += 1
-                if num_lookahead_slots > block.num_empty_slots:
-                    num_touched_blocks += cdiv(
-                        num_lookahead_slots - block.num_empty_slots,
-                        self._block_size)
-            else:
-                # If the block has a match in the cache and the cached block
-                # is not referenced, then we still count it as a touched block
-                if not self.is_block_cached(block) or \
-                    (block.content_hash is not None and \
-                     self._cached_blocks[block.content_hash] in self.evictor):
-                    num_touched_blocks += 1
+            # If the block has a match in the cache and the cached
+            # block is not referenced, then we still count it as a
+            # touched block.
+            if block.is_full and (not self.is_block_cached(block) or \
+                (block.content_hash is not None and \
+                 self._cached_blocks[block.content_hash] in \
+                    self.evictor)):
+                num_touched_blocks += 1
         return num_touched_blocks

     def swap_out(self, blocks: List[Block]) -> None:
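
For the prefix-caching allocator, a full block is skipped only when the destination already holds a live cached copy; a cached copy sitting in the evictor still has to be reclaimed, so it counts as touched. A rough sketch of that predicate, with boolean inputs standing in for the allocator's cache and evictor lookups (the real method additionally checks `content_hash`):

def counts_as_touched(is_full: bool,
                      is_cached_on_dst: bool,
                      cached_copy_in_evictor: bool) -> bool:
    # A full block must be materialized on the destination unless a live
    # (non-evictable) cached copy already exists there.
    return is_full and (not is_cached_on_dst or cached_copy_in_evictor)


assert not counts_as_touched(False, False, False)  # partial block: ignored
assert not counts_as_touched(True, True, False)    # live cached copy: free
assert counts_as_touched(True, True, True)         # cached but evictable
assert counts_as_touched(True, False, False)       # not cached: touched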

View File

@@ -1,5 +1,4 @@
 """A block manager that manages token blocks."""
-from itertools import chain
 from typing import Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Tuple
@@ -470,12 +469,31 @@ class BlockSpaceManagerV2(BlockSpaceManager):
             AllocStatus: The AllocStatus for swapping in/out the given
                 sequence_group on to the 'device'.
         """
-        blocks = self._get_blocks_for_swap(seq_group, status)
-        num_blocks_touched = self.block_allocator.get_num_blocks_touched(
-            blocks, device, num_lookahead_slots)
+        # First determine the number of blocks that will be touched by this
+        # swap. Then verify if there are available blocks in the device
+        # to perform the swap.
+        num_blocks_touched = 0
+        blocks: List[Block] = []
+        for seq in seq_group.get_seqs(status=status):
+            block_table = self.block_tables[seq.seq_id]
+            if block_table.blocks is not None:
+                # Compute the number of blocks to touch for the tokens to be
+                # appended. This does NOT include the full blocks that need
+                # to be touched for the swap.
+                num_blocks_touched += \
+                    block_table.get_num_blocks_touched_by_append_slots(
+                        block_table.get_unseen_token_ids(seq.get_token_ids()),
+                        num_lookahead_slots=num_lookahead_slots)
+                blocks.extend(block_table.blocks)
+        # Compute the number of full blocks to touch and add it to the
+        # existing count of blocks to touch.
+        num_blocks_touched += self.block_allocator.get_num_full_blocks_touched(
+            blocks, device=device)
         watermark_blocks = 0
         if device == Device.GPU:
             watermark_blocks = self.watermark_blocks
         if self.block_allocator.get_num_total_blocks(
                 device) < num_blocks_touched:
             return AllocStatus.NEVER
@@ -484,23 +502,3 @@ class BlockSpaceManagerV2(BlockSpaceManager):
             return AllocStatus.OK
         else:
             return AllocStatus.LATER
-
-    def _get_blocks_for_swap(self, seq_group: SequenceGroup,
-                             status: SequenceStatus) -> List[Block]:
-        """Returns the list of blocks those are touched by the seq_group
-
-        Args:
-            sequence_group (SequenceGroup): The sequence group to swap in.
-            status (SequenceStatus): The status of sequence which is needed
-                for action. RUNNING for swap out and SWAPPED for swap in
-
-        Returns:
-            The list of blocks those are touched by the seq_group.
-        """
-        blocks: Dict[int, List[Block]] = {}
-        for seq in seq_group.get_seqs(status=status):
-            block_table = self.block_tables[seq.seq_id]
-            if block_table.blocks is not None:
-                blocks[seq.seq_id] = block_table.blocks
-        combined_blocks = list(chain(*blocks.values()))
-        return combined_blocks
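
Putting the pieces together, the block manager's swap feasibility check now sums the full blocks touched by the swap with the blocks needed to append unseen tokens and lookahead slots, then maps that count to an `AllocStatus`. A hedged standalone sketch of the final decision (the free-block comparison against the watermark is assumed from the surrounding vLLM code, not shown in this hunk):

from enum import Enum


class AllocStatus(Enum):
    OK = 1
    LATER = 2
    NEVER = 3


def swap_feasibility(num_blocks_touched: int,
                     num_total_blocks: int,
                     num_free_blocks: int,
                     watermark_blocks: int = 0) -> AllocStatus:
    """Sketch: a swap that can never fit returns NEVER, one that fits now
    (while keeping the watermark free) returns OK, and anything else returns
    LATER so the scheduler can retry once blocks are freed."""
    if num_total_blocks < num_blocks_touched:
        return AllocStatus.NEVER
    if num_free_blocks - num_blocks_touched >= watermark_blocks:
        return AllocStatus.OK
    return AllocStatus.LATER


# One GPU block in total, as in test_swap_in_infeasible: a request needing
# two blocks can never be swapped in, one needing a single block can.
assert swap_feasibility(2, num_total_blocks=1,
                        num_free_blocks=1) == AllocStatus.NEVER
assert swap_feasibility(1, num_total_blocks=1,
                        num_free_blocks=1) == AllocStatus.OK
assert swap_feasibility(1, num_total_blocks=2,
                        num_free_blocks=0) == AllocStatus.LATER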