Prefix Cache Aware Scheduling [1/n] (#10128)
Signed-off-by: rickyx <rickyx@anyscale.com>
parent 7c25fe45a6
commit 4634a89d18
@@ -5,9 +5,14 @@ from unittest.mock import MagicMock

 import pytest

+from tests.core.utils import create_dummy_sequence
+from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
 from vllm.core.block.interfaces import Block, BlockAllocator
-from vllm.core.block.prefix_caching_block import (PrefixCachingBlock,
+from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
+                                                  PrefixCachingBlock,
                                                   PrefixCachingBlockAllocator)
+from vllm.sequence import Logprob
+from vllm.utils import Device


 class TestPrefixCachingBlock:
@@ -726,18 +731,71 @@ class TestPrefixCachingBlockAllocator:
             token_ids=common_token_ids,
             allocator=allocator,
         )
-        block_ids = [block.block_id for block in blocks]
+        block_hashes = [block.content_hash for block in blocks]
         # The allocated blocks should be marked as touched
         # but not computed.
-        computed_block_ids = allocator.get_computed_block_ids(
-            [], block_ids, skip_last_block_id=False)
+        computed_block_ids = allocator.find_cached_blocks_prefix(
+            block_hashes)
         assert len(computed_block_ids) == 0

         allocator.mark_blocks_as_computed([])
-        computed_block_ids = allocator.get_computed_block_ids(
-            [], block_ids, skip_last_block_id=False)
+        computed_block_ids = allocator.find_cached_blocks_prefix(
+            block_hashes=block_hashes)
         assert len(computed_block_ids) == common_blocks

+    @staticmethod
+    def test_find_cached_blocks_prefix():
+        """
+        This test verifies the behavior of find_cached_blocks_prefix.
+        """
+        block_size = 4
+        num_blocks = 8
+        total_test_blocks = 12
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+
+        token_ids = list(range(total_test_blocks * block_size))
+        block_tokens_seq1 = token_ids[:num_blocks * block_size]
+        blocks_seq1 = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=block_tokens_seq1,
+            allocator=allocator,
+        )
+        block_hashes_seq1 = [block.content_hash for block in blocks_seq1]
+        allocator.mark_blocks_as_computed([])
+
+        # All blocks should be cached.
+        cached_blocks_seq1 = allocator.find_cached_blocks_prefix(
+            block_hashes=block_hashes_seq1)
+        assert len(cached_blocks_seq1) == num_blocks
+
+        # Free the first sequence.
+        for block in blocks_seq1:
+            allocator.free(block)
+
+        # All blocks should be still be cached if not required to be allocated.
+        cached_blocks = allocator.find_cached_blocks_prefix(
+            block_hashes=block_hashes_seq1)
+        assert len(cached_blocks) == num_blocks
+
+        block_tokens_seq2 = token_ids[num_blocks * block_size:]
+        blocks_seq2 = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=block_tokens_seq2,
+            allocator=allocator,
+        )
+        block_hashes_seq2 = [block.content_hash for block in blocks_seq2]
+        allocator.mark_blocks_as_computed([])
+        cached_blocks = allocator.find_cached_blocks_prefix(
+            block_hashes=block_hashes_seq2)
+        assert len(cached_blocks) == len(blocks_seq2)
+
+        # Half of the blocks from seq1 should still be cached.
+        num_evicted_blocks = len(blocks_seq2)
+        cached_blocks = allocator.find_cached_blocks_prefix(
+            block_hashes=block_hashes_seq1)
+        assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks
+
     @staticmethod
     def create_immutable_chain(
         block_size: int,
@@ -762,3 +820,114 @@ class TestPrefixCachingBlockAllocator:
             blocks.append(prev_block)

         return blocks
+
+
+class TestComputedBlocksTracker:
+
+    @staticmethod
+    def _get_mock_allocator():
+        return MagicMock(spec=PrefixCachingBlockAllocator)
+
+    @staticmethod
+    def test_get_num_cached_tokens():
+        """
+        Test it correctly computes the number of cached tokens for a given
+        sequence:
+
+        - The cache token count is derived from the number of cached blocks.
+        - The cache token count is updated when the allocator is updated.
+        - When a sequence is removed, the cache token count should be updated
+        accordingly.
+
+        # TODO(rickyx): This behaviour for prefill sequence is a hack until
+        we fix the computed blocks tracking.
+        - The cache token count for prefill sequence doesn't change while
+        the sequence is in continuous prefill (chunked prefill).
+        """
+        block_size = 4
+        mock_allocator = TestComputedBlocksTracker._get_mock_allocator()
+        tracker = ComputedBlocksTracker(
+            allocator=mock_allocator,
+            block_size=block_size,
+            enable_caching=True,
+        )
+
+        # Not yet allocated.
+        tokens = [0, 1, 2, 3, 4, 5]
+        seq1 = create_dummy_sequence(request_id=0,
+                                     token_ids=tokens,
+                                     block_size=block_size)
+        mock_allocator.find_cached_blocks_prefix.return_value = []
+        assert tracker.get_num_cached_tokens(seq1) == 0
+
+        mock_allocator.find_cached_blocks_prefix.return_value = [
+            None
+        ]  # 1 block cached.
+        # Result is cached for prefill sequence.
+        assert tracker.get_num_cached_tokens(seq1) == 0
+
+        # Mark the sequence as non-prefill.
+        seq1.data.update_num_computed_tokens(len(tokens))  # 6 tokens computed.
+        assert not seq1.is_prefill()
+
+        # Recomputes for decoding sequence.
+        assert tracker.get_num_cached_tokens(seq1) == 4
+
+        # Append new tokens to the sequence.
+        num_new_tokens = 3
+        for i in range(num_new_tokens):
+            seq1.append_token_id(i, {i: Logprob(logprob=0.0)})
+
+        assert tracker.get_num_cached_tokens(seq1) == 4
+
+        # Update the allocator.
+        mock_allocator.find_cached_blocks_prefix.return_value = [
+            None
+        ] * 2  # 2 blocks cached.
+        assert tracker.get_num_cached_tokens(seq1) == 8
+
+        # Remove the sequence.
+        tracker.remove_seq(seq1.seq_id)
+
+        # Re-create the sequence with the same request id to simulate recompute.
+        seq1 = create_dummy_sequence(request_id=0,
+                                     token_ids=tokens,
+                                     block_size=block_size)
+        mock_allocator.find_cached_blocks_prefix.return_value = [
+        ]  # no cached block
+        assert tracker.get_num_cached_tokens(seq1) == 0
+
+    @staticmethod
+    def test_correct_block_hash():
+        """
+        Test that the block hash is correctly computed for a sequence (should
+        match the underlying block allocator's block hash). So the number of
+        cached tokens is correctly retrieved.
+        """
+        block_size = 4
+        allocator = CpuGpuBlockAllocator.create(
+            allocator_type="prefix_caching",
+            num_gpu_blocks=16,
+            num_cpu_blocks=16,
+            block_size=block_size,
+        )
+        gpu_allocator = allocator._allocators[Device.GPU]
+
+        tracker = ComputedBlocksTracker(
+            allocator=allocator,
+            block_size=block_size,
+            enable_caching=True,
+        )
+
+        tokens = list(range(block_size * 4))  # 4 blocks.
+        seq = create_dummy_sequence(request_id=0,
+                                    token_ids=tokens,
+                                    block_size=block_size)
+        _ = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=tokens,
+            allocator=gpu_allocator,
+        )
+        allocator.mark_blocks_as_computed([])
+
+        assert tracker.get_num_cached_tokens(seq) == len(tokens)
@@ -12,9 +12,9 @@ from vllm.core.scheduler import Scheduler, SchedulingBudget
 from vllm.lora.request import LoRARequest
 from vllm.sequence import SequenceGroup

-from .utils import (append_new_token, append_new_token_seq_group,
-                    create_dummy_prompt, get_sequence_groups,
-                    schedule_and_update_computed_tokens)
+from .utils import (append_new_token, append_new_token_seq,
+                    append_new_token_seq_group, create_dummy_prompt,
+                    get_sequence_groups, schedule_and_update_computed_tokens)


 def test_scheduler_add_seq_group():
@@ -305,6 +305,8 @@ def initialize_scheduler(
     block_size=4,
     num_cpu_blocks=8,
     num_gpu_blocks=8,
+    enable_prefix_caching=False,
+    enable_chunked_prefill=False,
 ):
     block_size = block_size
     scheduler_config = SchedulerConfig(
@@ -312,8 +314,15 @@ def initialize_scheduler(
         max_num_batched_tokens=max_token_budget,
         max_num_seqs=max_num_seqs,
         max_model_len=max_model_len,
+        enable_chunked_prefill=enable_chunked_prefill,
+    )
+    cache_config = CacheConfig(
+        block_size,
+        1.0,
+        1,
+        "auto",
+        enable_prefix_caching=enable_prefix_caching,
     )
-    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
     cache_config.num_cpu_blocks = num_cpu_blocks
     cache_config.num_gpu_blocks = num_gpu_blocks
     scheduler = Scheduler(scheduler_config, cache_config, lora_config)
@@ -800,3 +809,165 @@ def test_scheduling_budget():
     assert budget.num_curr_seqs == 0
     budget.subtract_num_seqs(seq_group.request_id, 2)
     assert budget.num_curr_seqs == 0
+
+
+@pytest.mark.parametrize("enable_prefix_caching", [True, False])
+def test_prefix_caching_aware_prefills(enable_prefix_caching):
+    """
+    Test the below scenario:
+
+    For 3 sequences, seqA, seqB, seqC, share the first block as prefix.
+
+    The test verifies the below scenarios:
+    1. SeqA is first scheduled.
+    2. SeqB and SeqC can be prefilled together in a single schedule round
+    even though there are not enough token budgets to prefill both without
+    considering prefix caching.
+    """
+
+    block_size = 4
+    max_num_batched_tokens = 12
+    max_seq_group = 3
+    scheduler = initialize_scheduler(
+        block_size=block_size,
+        num_cpu_blocks=16,
+        num_gpu_blocks=16,
+        max_token_budget=max_num_batched_tokens,
+        max_num_seqs=max_seq_group,
+        max_model_len=max_num_batched_tokens,
+        enable_prefix_caching=enable_prefix_caching,
+    )
+
+    seqA_tokens = list(range(8))
+    num_shared_tokens = 4
+    seqB_tokens = seqA_tokens[:num_shared_tokens] + list(range(
+        12, 16))  # Shared prefix first 4.
+    seqC_tokens = seqA_tokens[:num_shared_tokens] + list(range(
+        16, 20))  # Shared prefix first 4.
+
+    seqA, seqA_group = create_dummy_prompt("0",
+                                           prompt_tokens=seqA_tokens,
+                                           block_size=block_size)
+    seqB, seqB_group = create_dummy_prompt("1",
+                                           prompt_tokens=seqB_tokens,
+                                           block_size=block_size)
+    seqC, seqC_group = create_dummy_prompt("2",
+                                           prompt_tokens=seqC_tokens,
+                                           block_size=block_size)
+
+    # Schedule seqA prefill.
+    scheduler.add_seq_group(seqA_group)
+    metas, out, _ = scheduler.schedule()
+    assert (len(out.scheduled_seq_groups) == 1
+            and out.scheduled_seq_groups[0].seq_group == seqA_group)
+    assert out.scheduled_seq_groups[0].token_chunk_size == len(seqA_tokens)
+
+    # Schedule seqA decode.
+    append_new_token_seq_group(len(seqA_tokens), seqA_group, 999)
+    metas, out, _ = scheduler.schedule()
+
+    assert len(out.scheduled_seq_groups) == 1
+    assert out.scheduled_seq_groups[0].seq_group == seqA_group
+    assert out.scheduled_seq_groups[0].token_chunk_size == 1
+
+    # Schedule seqB and seqC prefills should work with prefix caching.
+    scheduler.add_seq_group(seqB_group)
+    scheduler.add_seq_group(seqC_group)
+    metas, out, _ = scheduler.schedule()
+
+    if enable_prefix_caching:
+        assert len(out.scheduled_seq_groups) == 2
+        assert set([
+            out.scheduled_seq_groups[0].seq_group,
+            out.scheduled_seq_groups[1].seq_group,
+        ]) == set([seqB_group, seqC_group])
+        assert len(metas) == 2
+        for meta in metas:
+            assert meta.token_chunk_size == 8
+            assert (len(meta.computed_block_nums) == num_shared_tokens //
+                    block_size)  # 1 Block for the 8 tokens.
+    else:
+        assert len(out.scheduled_seq_groups) == 1
+        assert len(metas) == 1
+        assert metas[0].token_chunk_size == 8
+        assert len(metas[0].computed_block_nums) == 0  # No blocks computed.
+
+
+def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
+):
+    """
+    This test verifies that we don't schedule new prefills if there's already
+    a continuous prefill in progress even though the new prefills with shared
+    prefix can fit in the token budget:
+
+    - SeqA is being chunked prefill.
+    - SeqB with the same prompt shouldn't be scheduled for prefill even though
+    there's enough token budget to prefill the cached tokens.
+    - Neither should seqC be scheduled.
+
+    - When seqA is in decoding phase, seqB and seqC can be scheduled.
+    - Entire seqB should be prefilled since it's a full prefix cache hit.
+    - SeqC would be partially prefilled with the prefix shared, and the
+    remaining unique tokens would be prefilled (rounded down to be
+    block-size aligned).
+    """
+
+    block_size = 2
+    max_num_batched_tokens = 4
+    max_seq_group = 3
+    scheduler = initialize_scheduler(
+        block_size=block_size,
+        num_cpu_blocks=16,
+        num_gpu_blocks=16,
+        max_token_budget=max_num_batched_tokens,
+        max_num_seqs=max_seq_group,
+        max_model_len=100,
+        enable_prefix_caching=True,
+        enable_chunked_prefill=True,
+    )
+
+    seqA_tokens = list(range(8))
+    seqB_tokens = seqA_tokens
+    seqC_shared_prefix_len = 4
+    seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20))
+
+    seqA, seqA_group = create_dummy_prompt("0",
+                                           prompt_tokens=seqA_tokens,
+                                           block_size=block_size)
+    seqB, seqB_group = create_dummy_prompt("1",
+                                           prompt_tokens=seqB_tokens,
+                                           block_size=block_size)
+
+    # Chunked prefill seqA.
+    scheduler.add_seq_group(seqA_group)
+    metas, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(out.scheduled_seq_groups) == 1
+    assert out.scheduled_seq_groups[0].seq_group == seqA_group
+    assert out.scheduled_seq_groups[0].token_chunk_size == 4
+
+    # seqB should not be scheduled with ongoing prefills.
+    scheduler.add_seq_group(seqB_group)
+    metas, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(out.scheduled_seq_groups) == 1
+    assert out.scheduled_seq_groups[0].seq_group == seqA_group
+    assert out.scheduled_seq_groups[0].token_chunk_size == 4
+
+    # both seqB and seqC can now be scheduled with seqA is over.
+    # seqA is in decoding phase.
+    append_new_token_seq(seqA, 999)
+    seqC, seqC_group = create_dummy_prompt("2",
+                                           prompt_tokens=seqC_tokens,
+                                           block_size=block_size)
+    scheduler.add_seq_group(seqC_group)
+    metas, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(out.scheduled_seq_groups) == 3
+
+    metas = {meta.request_id: meta for meta in metas}
+    assert metas[seqA_group.request_id].token_chunk_size == 1  # Decode
+    assert (metas[seqB_group.request_id].token_chunk_size == 8
+            )  # Fully cached prefill
+    assert (
+        metas[seqC_group.request_id].token_chunk_size == 6
+    ), "A partial prefix of C (4 tokens) should be prefilled, with the "
+    "remaining tokens fit into 3 token budget (4-1 from the seqA). It will "
+    "then be rounded down to 2 tokens on block size, thus 6 tokens in total."
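
Worked numbers for the prefix-caching case above (a standalone illustration of the test's arithmetic, not code from the diff):

    # With prefix caching, only uncached tokens count against the token budget.
    token_budget = 12
    prompt_len = 8        # seqB and seqC each have 8 prompt tokens
    cached_prefix = 4     # the first 4-token block is already computed by seqA
    uncached_per_seq = prompt_len - cached_prefix
    assert 2 * uncached_per_seq <= token_budget   # 8 <= 12: B and C co-schedule
    assert 2 * prompt_len > token_budget          # 16 > 12: impossible without caching
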
@@ -1,17 +1,20 @@
 import time
-from typing import List, Optional
+from collections import defaultdict
+from typing import Any, Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Tuple

 from vllm import SamplingParams
+from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.inputs import EncoderDecoderInputs, token_inputs
 from vllm.lora.request import LoRARequest
-from vllm.sequence import Logprob, Sequence, SequenceGroup
+from vllm.sequence import (Logprob, Sequence, SequenceGroup,
+                           SequenceGroupMetadata)


 def create_dummy_prompt(
     request_id: str,
-    prompt_length: int,
+    prompt_length: int = -1,
     block_size: Optional[int] = None,
     lora_request: Optional[LoRARequest] = None,
     best_of: int = 1,
@@ -26,6 +29,7 @@ def create_dummy_prompt(
     # Create dummy prompt sequence with tokens 0...block_size-1
     # and prompt "0 ... block_size".
     prompt_tokens = list(range(prompt_length))
+
     prompt_str = " ".join([str(t) for t in prompt_tokens])
     prompt = Sequence(int(request_id),
                       inputs=token_inputs(prompt_tokens, prompt=prompt_str),
@@ -42,6 +46,15 @@ def create_dummy_prompt(
     return prompt, seq_group


+def create_dummy_sequence(request_id: int, token_ids: List[int],
+                          block_size: int) -> Sequence:
+    return Sequence(
+        seq_id=request_id,
+        inputs=token_inputs(token_ids),
+        block_size=block_size,
+    )
+
+
 def create_dummy_prompt_encoder_decoder(
     request_id: str,
     decoder_prompt_length: int,
@@ -194,12 +207,40 @@ def append_new_token(out, token_id: int):

 def schedule_and_update_computed_tokens(scheduler):
     metas, out, _ = scheduler.schedule()
-    for s, meta in zip(out.scheduled_seq_groups, metas):
-        s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
+    for s in out.scheduled_seq_groups:
+        s.seq_group.update_num_computed_tokens(s.token_chunk_size)
     return metas, out


+def append_new_token_seq(seq: Sequence, token_id: int):
+    seq.append_token_id(token_id, {token_id: Logprob(token_id)})
+
+
 def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
     seq_group.update_num_computed_tokens(token_chunk_size)
     for seq in seq_group.get_seqs():
         seq.append_token_id(token_id, {token_id: Logprob(token_id)})
+
+
+class SchedulerProxy:
+    """
+    A proxy class to forward calls to the scheduler.
+    """
+
+    def __init__(self, scheduler: Scheduler):
+        self.scheduler_ = scheduler
+        self.call_history: Dict[str, List[Any]] = defaultdict(list)
+
+    def __getattr__(self, name: str) -> Any:
+
+        def wrapper(*args, **kwargs):
+            result = getattr(self.scheduler_, name)(*args, **kwargs)
+            self.call_history[name].append((args, kwargs, result))
+            return result
+
+        return wrapper
+
+    def last_schedule_ret(
+        self, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Any]:
+        _, _, ret = self.call_history["schedule"][-1]
+        return ret
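
A usage sketch of SchedulerProxy (variable names are hypothetical; the third value returned by schedule() is labelled here as an assumption, it is not named in this diff):

    # Hypothetical sketch: the proxy forwards every call and records
    # (args, kwargs, result) per method name.
    proxy = SchedulerProxy(real_scheduler)   # real_scheduler: an existing Scheduler
    proxy.add_seq_group(seq_group)           # forwarded via __getattr__
    metas, outputs, allow_async = proxy.schedule()
    assert proxy.call_history["schedule"][-1][2] == (metas, outputs, allow_async)
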
@@ -2,10 +2,15 @@

 Run `pytest tests/prefix_caching/test_prefix_caching.py`.
 """

 import pytest

+from tests.conftest import VllmRunner
+from tests.core.utils import SchedulerProxy, create_dummy_prompt
 from tests.kernels.utils import override_backend_env_variable
 from vllm import SamplingParams, TokensPrompt
+from vllm.core.scheduler import Scheduler
+from vllm.engine.llm_engine import LLMEngine

 from ..models.utils import check_outputs_equal

@@ -27,6 +32,7 @@ UNSTABLE_PROMPT_SEQUENCE = [
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("cached_position", [0, 1])
+@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
 @pytest.mark.parametrize("block_size", [16])
 def test_mixed_requests(
     hf_runner,
@@ -37,6 +43,7 @@ def test_mixed_requests(
     dtype: str,
     max_tokens: int,
     cached_position: int,
+    enable_chunked_prefill: bool,
     block_size: int,
     monkeypatch,
 ) -> None:
@@ -55,6 +62,7 @@ def test_mixed_requests(
             model,
             dtype=dtype,
             enable_prefix_caching=True,
+            enable_chunked_prefill=enable_chunked_prefill,
             block_size=block_size,
     ) as vllm_model:
         # Run the first prompt so the cache is populated
@@ -72,13 +80,13 @@ def test_mixed_requests(
                     block_size) * block_size
             else:
                 expected_num_cached_tokens = 0
-            assert req_outputs[
-                i].num_cached_tokens == expected_num_cached_tokens
+            assert (
+                req_outputs[i].num_cached_tokens == expected_num_cached_tokens)

-        vllm_outputs = [
-            (output.prompt_token_ids + list(output.outputs[0].token_ids),
-             output.prompt + output.outputs[0].text) for output in req_outputs
-        ]
+        vllm_outputs = [(
+            output.prompt_token_ids + list(output.outputs[0].token_ids),
+            output.prompt + output.outputs[0].text,
+        ) for output in req_outputs]

         check_outputs_equal(
             outputs_0_lst=hf_outputs,
@@ -105,3 +113,89 @@ def test_unstable_prompt_sequence(
         for prompt in UNSTABLE_PROMPT_SEQUENCE:
             vllm_model.generate(TokensPrompt(prompt_token_ids=prompt),
                                 SamplingParams(max_tokens=1))
+
+
+@pytest.mark.parametrize("model", MODELS)
+def test_fully_cached_prefill_needs_uncached_token(model):
+    block_size = 16
+    max_num_batched_tokens = 16
+    num_output_tokens = 5
+    # Make a vllm engine
+    runner = VllmRunner(
+        model_name=model,
+        gpu_memory_utilization=0.7,
+        enable_chunked_prefill=True,
+        enforce_eager=True,
+        enable_prefix_caching=True,
+        block_size=block_size,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_batched_tokens,
+    )
+    engine: LLMEngine = runner.model.llm_engine
+
+    scheduler: Scheduler = SchedulerProxy(engine.scheduler[0])  # type: ignore
+    engine.scheduler[0] = scheduler
+
+    # SeqA
+    seqA_tokens = list(range(2 * block_size))
+    seqA, seq_groupA = create_dummy_prompt(
+        request_id="0",
+        prompt_tokens=seqA_tokens,
+        max_tokens=num_output_tokens,
+        block_size=block_size,
+    )
+
+    scheduler.add_seq_group(seq_groupA)
+
+    assert seqA.data.get_num_computed_tokens() == 0
+
+    # Prefill seqA
+    while not seqA.is_finished():
+        engine.step()
+
+    # seqB
+    seqB_tokens = [t + 1 for t in seqA_tokens]  # shift by 1
+    seqB, seq_groupB = create_dummy_prompt(
+        request_id="1",
+        prompt_tokens=seqB_tokens,
+        max_tokens=num_output_tokens,
+        block_size=block_size,
+    )
+
+    # seqC is the same as seqA
+    seqC, seq_groupC = create_dummy_prompt(
+        request_id="2",
+        prompt_tokens=seqA_tokens,
+        max_tokens=num_output_tokens,
+        block_size=block_size,
+    )
+
+    scheduler.add_seq_group(seq_groupB)
+    scheduler.add_seq_group(seq_groupC)
+
+    # Even seqC is fully cached, it should not be prefilled since we
+    # require at least 1 uncached token.
+    engine.step()
+
+    sched_metas, sched_out, _ = scheduler.last_schedule_ret()
+    assert len(sched_out.scheduled_seq_groups) == 1
+    assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
+            seq_groupB.request_id)
+    assert (sched_out.scheduled_seq_groups[0].token_chunk_size ==
+            max_num_batched_tokens)
+
+    # When seqB is finished, seqC could be prefilled.
+    while not seqB.is_finished():
+        engine.step()
+        sched_metas, sched_out, _ = scheduler.last_schedule_ret()
+        assert len(sched_out.scheduled_seq_groups) == 1
+        assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
+                seq_groupB.request_id)
+
+    engine.step()
+    sched_metas, sched_out, _ = scheduler.last_schedule_ret()
+    assert len(sched_out.scheduled_seq_groups) == 1
+    assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
+            seq_groupC.request_id)
+    assert sched_out.scheduled_seq_groups[0].token_chunk_size == len(
+        seqA_tokens)
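
The "needs at least one uncached token" behaviour exercised by the test above can be sketched in isolation; this is an illustration of the constraint only, not the scheduler's actual implementation (which is outside this diff):

    # Hypothetical sketch: even a fully cached prompt must keep >= 1 token
    # uncached, otherwise there is no forward pass to produce the next token.
    def split_prefill_tokens(prompt_len: int, num_cached: int):
        num_cached = min(num_cached, prompt_len)
        if num_cached == prompt_len:
            num_cached = prompt_len - 1  # leave one token to recompute
        return prompt_len - num_cached, num_cached  # (uncached, cached)

    assert split_prefill_tokens(32, 32) == (1, 31)
    assert split_prefill_tokens(32, 16) == (16, 16)
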
@@ -306,14 +306,6 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         device = Device.GPU
         return self._allocators[device].mark_blocks_as_computed(block_ids)

-    def get_computed_block_ids(self, prev_computed_block_ids: List[int],
-                               block_ids: List[int],
-                               skip_last_block_id: bool) -> List[int]:
-        # Prefix caching only supported on GPU.
-        device = Device.GPU
-        return self._allocators[device].get_computed_block_ids(
-            prev_computed_block_ids, block_ids, skip_last_block_id)
-
     def get_common_computed_block_ids(
             self, computed_seq_block_ids: List[List[int]]) -> List[int]:
         # Prefix caching only supported on GPU.
@@ -342,6 +334,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         self._swap_mapping.clear()
         return list(mapping.items())

+    def find_cached_blocks_prefix(
+        self,
+        block_hashes: List[int],
+        device: Device = Device.GPU,
+    ) -> List[int]:
+        return self._allocators[device].find_cached_blocks_prefix(block_hashes)
+

 class NullBlock(Block):
     """
@@ -159,12 +159,6 @@ class BlockAllocator(ABC):
     def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
         pass

-    @abstractmethod
-    def get_computed_block_ids(self, prev_computed_block_ids: List[int],
-                               block_ids: List[int],
-                               skip_last_block_id: bool) -> List[int]:
-        pass
-
     @abstractmethod
     def get_common_computed_block_ids(
             self, computed_seq_block_ids: List[List[int]]) -> List[int]:
@@ -192,6 +186,13 @@ class BlockAllocator(ABC):
     class NoFreeBlocksError(ValueError):
         pass

+    @abstractmethod
+    def find_cached_blocks_prefix(
+        self,
+        block_hashes: List[int],
+    ) -> List[int]:
+        pass
+

 class DeviceAwareBlockAllocator(ABC):

@@ -207,9 +208,12 @@ class DeviceAwareBlockAllocator(ABC):
         pass

     @abstractmethod
-    def allocate_immutable_blocks(self, prev_block: Optional[Block],
+    def allocate_immutable_blocks(
+        self,
+        prev_block: Optional[Block],
         block_token_ids: List[List[int]],
-                                  device: Device) -> List[Block]:
+        device: Device,
+    ) -> List[Block]:
         pass

     @abstractmethod
@@ -246,12 +250,6 @@ class DeviceAwareBlockAllocator(ABC):
     def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
         pass

-    @abstractmethod
-    def get_computed_block_ids(self, prev_computed_block_ids: List[int],
-                               block_ids: List[int],
-                               skip_last_block_id: bool) -> List[int]:
-        pass
-
     @abstractmethod
     def get_common_computed_block_ids(
             self, computed_seq_block_ids: List[List[int]]) -> List[int]:
@@ -284,3 +282,11 @@ class DeviceAwareBlockAllocator(ABC):
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         """Prefix cache hit rate. -1 means not supported or disabled."""
         pass
+
+    @abstractmethod
+    def find_cached_blocks_prefix(
+        self,
+        block_hashes: List[int],
+        device: Device = Device.GPU,
+    ) -> List[int]:
+        pass
@@ -262,13 +262,6 @@ class NaiveBlockAllocator(BlockAllocator):
         """
         pass

-    def get_computed_block_ids(self, prev_computed_block_ids: List[int],
-                               block_ids: List[int],
-                               skip_last_block_id: bool) -> List[int]:
-        """No prefix caching here => return empty list
-        """
-        return []
-
     def get_common_computed_block_ids(
             self, computed_seq_block_ids: List[List[int]]) -> List[int]:
         """Determine blocks that can be skipped in prefill.
@@ -329,6 +322,10 @@ class NaiveBlockAllocator(BlockAllocator):
     def get_prefix_cache_hit_rate(self) -> float:
         return -1

+    def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
+        # Not applicable for naive block allocator.
+        return []
+

 class NaiveBlock(Block):
     """An implementation of the Block class that does not support prefix
@@ -1,13 +1,18 @@
 """Token blocks."""
+import sys
+from bisect import bisect_left
 from os.path import commonprefix
-from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple
+from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set,
+                    Tuple)

 from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
                                     get_all_blocks_recursively)
-from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
+from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device,
+                                        DeviceAwareBlockAllocator)
 from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
                                          NaiveBlockAllocator)
 from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
+from vllm.sequence import Sequence

 PrefixHash = int

@@ -534,26 +539,6 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         else:
             return block_id in self.evictor

-    def get_computed_block_ids(self,
-                               prev_computed_block_ids: List[int],
-                               block_ids: List[int],
-                               skip_last_block_id: bool = True) -> List[int]:
-        prev_prefix_size = len(prev_computed_block_ids)
-        cur_size = len(block_ids)
-        if skip_last_block_id:
-            cur_size -= 1
-
-        # Sanity checks
-        assert cur_size >= 0
-        assert prev_prefix_size <= cur_size
-
-        ret = prev_computed_block_ids
-        for i in range(prev_prefix_size, cur_size):
-            block_id = block_ids[i]
-            if self.block_is_computed(block_id):
-                ret.append(block_id)
-        return ret
-
     def get_common_computed_block_ids(
             self, computed_seq_block_ids: List[List[int]]) -> List[int]:
         """Return the block ids that are common for a given sequence group.
@@ -634,6 +619,47 @@ class PrefixCachingBlockAllocator(BlockAllocator):

         block.block_id = block_id  # Assign block_id

+    def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
+        """
+        Given a list of block hashes, return the prefix of the block hashes that
+        are all cached.
+
+        Since a block's block hash includes the hashes of all previous blocks,
+        and we only allocate/deallocate blocks in the entire sequence, so if a
+        block is cached, then all previous blocks are also cached. With this
+        property, we can use binary search to find the prefix of cached blocks.
+
+        Args:
+            block_hashes (List[int]): The list of block hashes.
+
+        Returns:
+            List[int]: The prefix of the `block_hashes` that are cached.
+        """
+
+        def _block_is_cached(block_hash: PrefixHash) -> bool:
+            if block_hash not in self._cached_blocks:
+                return False
+
+            cached_block_id = self._cached_blocks[block_hash]
+            # We only consider the blocks that are marked as computed.
+            return self.block_is_computed(cached_block_id)
+
+        def _bisect_left(a, x, key: Callable[[PrefixHash], bool]) -> int:
+
+            # python <= 3.10 don't have the key argument
+            if sys.version_info < (3, 10):
+                a = [key(e) for e in a]
+                return bisect_left(a, x)
+            else:
+                return bisect_left(a, x, key=key)
+
+        # Look for the first block that's not cached, and returns the prefix
+        # i.e. blocks that are cached.
+        idx = _bisect_left(block_hashes,
+                           True,
+                           key=lambda x: not _block_is_cached(x))
+        return block_hashes[:idx]
+

 class PrefixCachingBlock(Block):
     """A block implementation that supports prefix caching.
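
To make the binary-search argument in the docstring concrete, here is a small standalone sketch; `cached` is a stand-in for the allocator's computed-block lookup and is not code from the diff:

    from bisect import bisect_left

    def cached_prefix(block_hashes, cached):
        # Chained hashes make "is cached" monotone over the list:
        # True, ..., True, False, ..., False. bisect_left on the negated
        # predicate finds the first uncached block.
        idx = bisect_left([h not in cached for h in block_hashes], True)
        return block_hashes[:idx]

    # Blocks with hashes 10-12 cached, 13-14 not: the cached prefix has length 3.
    assert cached_prefix([10, 11, 12, 13, 14], {10, 11, 12}) == [10, 11, 12]

This fallback evaluates the predicate for every element, as the pre-3.10 branch above does; the key= form on Python 3.10+ keeps the number of probes logarithmic.
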
@@ -843,86 +869,126 @@ class PrefixCachingBlock(Block):


 class ComputedBlocksTracker:
-    """Handles caching of per-sequence computed block ids.
-    When a sequence appears for the first time, it traverses all of the
-    blocks and detects the prefix of blocks that is computed. On the
-    subsequent times, it only traverses the new blocks that were added
-    and updates the already recorded prefix of blocks with the newly
-    computed blocks.
-
-    To avoid redundant traversals, the algorithm also detects when there
-    is a "gap" in the computed prefix. For example, if we have blocks =
-    [1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then
-    we won't try to add more computed blocks to [1,2,3] in this sequence
-    iteration, and will add more computed blocks only after the sequence is
-    freed and reused again.
-
-    Note that currently, for a given sequence, we also skip the last
-    block id for caching purposes, to avoid caching of a full sequence
+    """
+    Tracks the computed blocks for each sequence.
+
+    Internally, it maintains a map from sequence id to the list of block hashes
+    for the sequence. We cache the hashes of the full blocks for each sequence,
+    and make sure the hash is calculated in the same way as the allocator.
+    When a sequence is being decoded, we also update the sequence's hash
+    accordingly and incrementally.
+
+    From the sequence hash, with prefix caching enabled, we could also calculate
+    the number of cached tokens for the sequence by looking up the number of
+    cached block hashes in the allocator.
     """

-    def __init__(self, allocator):
+    def __init__(
+        self,
+        allocator: DeviceAwareBlockAllocator,
+        block_size: int,
+        enable_caching: bool,
+    ):
         self._allocator = allocator
-        self._cached_computed_seq_blocks: Dict[int, Tuple[List[int],
-                                                          bool]] = {}
-
-    def add_seq(self, seq_id: int) -> None:
-        """Start tracking seq_id
-        """
-        assert seq_id not in self._cached_computed_seq_blocks
-        self._cached_computed_seq_blocks[seq_id] = ([], False)
+        self._block_size = block_size
+        self._enable_caching = enable_caching
+
+        # A map from seq_id to the list of block hashes for the
+        # sequence. This is so that we don't have to recompute the block hashes
+        # for the sequence when we need to check if the sequence is cached.
+        # Note a block that's not full will not have its hash calculated and
+        # recorded.
+        self._seq_id_to_blocks_hashes: Dict[int, List[int]] = {}

+        # A map from seq_id to the number of tokens that are cached for the
+        # sequence.
+        # We need this so that a sequence in continuous prefill doesn't
+        # accidentally see its cached token count change. See comments in
+        # `get_num_cached_tokens` for more details.
+        self._seq_id_to_num_tokens_computed: Dict[int, int] = {}
+
+    def _update_seq_hashes(self, seq: Sequence) -> None:
+        """Incrementally update the sequence's block hashes and record them."""
+        assert self._enable_caching
+
+        block_hashes_recorded = self._seq_id_to_blocks_hashes.get(
+            seq.seq_id, [])
+        cur_num_blocks_recorded = len(block_hashes_recorded)
+        token_ids = seq.get_token_ids()
+        assert len(token_ids) >= cur_num_blocks_recorded * self._block_size, (
+            f"The sequence has {len(token_ids)} tokens, but"
+            f" already recorded {cur_num_blocks_recorded} blocks. "
+            "This should not happen since we assume blocks are "
+            "only appended other than recomputation. When the sequence is "
+            "recomputed, we should have removed the info of the old blocks.")
+        # Update the computed block hashes for the sequence. Since only full
+        # blocks are considered as "computed", we take floor here.
+        num_computed_blocks = len(token_ids) // self._block_size
+
+        # We need to know the hash of the previous block to compute the hash of
+        # the current block so that blocks could be uniquely identified across
+        # sequences of prefixes.
+        prev_block_hash = (None if cur_num_blocks_recorded == 0 else
+                           block_hashes_recorded[-1])
+        # Only update the computed block hashes for the new blocks
+        for i in range(cur_num_blocks_recorded, num_computed_blocks):
+            assert len(token_ids) >= (i + 1) * self._block_size
+            block_token_ids = token_ids[i * self._block_size:(i + 1) *
+                                        self._block_size]
+            # This has to be kept in sync with the allocator's hash
+            # calculation.
+            block_hash = PrefixCachingBlock.hash_block_tokens(
+                is_first_block=prev_block_hash is None,
+                prev_block_hash=prev_block_hash,
+                cur_block_token_ids=block_token_ids,
+            )
+            block_hashes_recorded.append(block_hash)
+            prev_block_hash = block_hash
+
+        self._seq_id_to_blocks_hashes[seq.seq_id] = block_hashes_recorded
+
+    def get_num_cached_tokens(self, seq: Sequence) -> int:
+        if not self._enable_caching:
+            return 0
+
+        # We always try to update the sequence hashes on the fly.
+        # This is to ensure that we don't miss any cached tokens for the
+        # sequence during decode.
+        # This routine should only update hash for any new blocks too.
+        self._update_seq_hashes(seq)
+
+        num_computed_tokens_prev = self._seq_id_to_num_tokens_computed.get(
+            seq.seq_id, None)
+
+        # TODO(rickyx): This hack could be removed once we mark blocks as
+        # computed correctly with chunked prefills.
+        if num_computed_tokens_prev is not None and seq.is_prefill():
+            # For a sequence that is still in prefill, we don't
+            # recompute the number of cached tokens.
+            # This also handles correctly chunked prefill since currently
+            # we mark blocks as computed even if the sequence is still partially
+            # prefilled. So a continuously prefilled sequence should not
+            # see its cached token count change while running.
+            return num_computed_tokens_prev
+
+        block_hashes = self._seq_id_to_blocks_hashes[seq.seq_id]
+
+        # This is O(logN), where N is the number of blocks.
+        num_cached_blocks = len(
+            self._allocator.find_cached_blocks_prefix(block_hashes))
+        num_cached_tokens = num_cached_blocks * self._block_size
+        self._seq_id_to_num_tokens_computed[seq.seq_id] = num_cached_tokens
+        return num_cached_tokens
+
     def remove_seq(self, seq_id: int) -> None:
-        """Stop tracking seq_id
-        """
-        assert seq_id in self._cached_computed_seq_blocks
-        del self._cached_computed_seq_blocks[seq_id]
+        """Stop tracking the sequence."""
+        if not self._enable_caching:
+            return
+        assert seq_id in self._seq_id_to_blocks_hashes
+        del self._seq_id_to_blocks_hashes[seq_id]
+
+        assert seq_id in self._seq_id_to_num_tokens_computed
+        del self._seq_id_to_num_tokens_computed[seq_id]

-    def get_cached_computed_blocks_and_update(
-            self, seq_id: int, block_ids: List[int]) -> List[int]:
-        """ Look at the class documentation for details
-        """
-        # Ensure seq_id is already tracked
-        assert seq_id in self._cached_computed_seq_blocks
-
-        # Get cached data (may be empty on the first time)
-        prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[
-            seq_id]
-
-        if has_gap:
-            # When gap is detected, we do not add more computed blocks at this
-            # sequence iteration
-            return prev_computed_block_ids
-
-        # We do not consider the last block id for caching purposes.
-        num_cur_blocks = len(block_ids) - 1
-        assert num_cur_blocks >= 0
-
-        if len(prev_computed_block_ids) >= num_cur_blocks:
-            # Cache HIT
-            assert len(prev_computed_block_ids) == num_cur_blocks
-            return prev_computed_block_ids
-
-        # If here, then we may possibly add more computed blocks. As a result,
-        # traverse the additional blocks after prev_computed_block_ids to
-        # detect more computed blocks and add them.
-
-        # Incremental init for seq_id => Look only at the new blocks
-        computed_block_ids = self._allocator.get_computed_block_ids(  # noqa: E501
-            prev_computed_block_ids,
-            block_ids,
-            skip_last_block_id=
-            True,  # We skip last block id to avoid caching of full seq
-        )
-
-        # Detect if there is a "gap"
-        has_gap = len(computed_block_ids) < num_cur_blocks
-
-        # Record
-        self._cached_computed_seq_blocks[seq_id] = (computed_block_ids,
-                                                    has_gap)
-
-        return computed_block_ids
-

 class LastAccessBlocksTracker:
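
The tracker's bookkeeping reduces to two steps: chain one hash per full block, then turn the allocator's cached prefix into a token count. A minimal standalone sketch of that arithmetic (the hash function and the cached-set lookup are placeholders, not vLLM's actual helpers):

    def chained_block_hashes(token_ids, block_size):
        # One hash per *full* block; each hash also covers all earlier blocks.
        hashes, prev = [], None
        for i in range(len(token_ids) // block_size):
            block = tuple(token_ids[i * block_size:(i + 1) * block_size])
            prev = hash((prev, block))  # placeholder for the allocator's hash
            hashes.append(prev)
        return hashes

    def num_cached_tokens(hashes, cached, block_size):
        n = 0
        while n < len(hashes) and hashes[n] in cached:
            n += 1  # cached blocks always form a prefix
        return n * block_size
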
@@ -101,7 +101,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
         self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}

         self._computed_blocks_tracker = ComputedBlocksTracker(
-            self.block_allocator)
+            self.block_allocator, self.block_size, self.enable_caching)
         self._last_access_blocks_tracker = LastAccessBlocksTracker(
             self.block_allocator)

@@ -170,7 +170,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
         self.block_tables[seq.seq_id] = block_table

         # Track seq
-        self._computed_blocks_tracker.add_seq(seq.seq_id)
         self._last_access_blocks_tracker.add_seq(seq.seq_id)

         # Assign the block table for each sequence.
@@ -178,7 +177,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
             self.block_tables[seq.seq_id] = block_table.fork()

             # Track seq
-            self._computed_blocks_tracker.add_seq(seq.seq_id)
             self._last_access_blocks_tracker.add_seq(seq.seq_id)

         # Allocate cross-attention block table for encoder sequence
@@ -314,11 +312,13 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
         """
         computed_seq_block_ids = []
         for seq in seqs:
-            computed_seq_block_ids.append(
-                self._computed_blocks_tracker.
-                get_cached_computed_blocks_and_update(
-                    seq.seq_id,
-                    self.block_tables[seq.seq_id].physical_block_ids))
+            all_blocks = self.block_tables[seq.seq_id].physical_block_ids
+            num_cached_tokens = (
+                self._computed_blocks_tracker.get_num_cached_tokens(seq))
+            assert num_cached_tokens % self.block_size == 0
+            num_cached_blocks = num_cached_tokens // self.block_size
+            computed_block_ids = all_blocks[:num_cached_blocks]
+            computed_seq_block_ids.append(computed_block_ids)

         # NOTE(sang): This assumes seq_block_ids doesn't contain any None.
         return self.block_allocator.get_common_computed_block_ids(
@@ -332,7 +332,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
         self.block_tables[child_seq.seq_id] = src_block_table.fork()

         # Track child seq
-        self._computed_blocks_tracker.add_seq(child_seq.seq_id)
         self._last_access_blocks_tracker.add_seq(child_seq.seq_id)

     def can_swap_in(self, seq_group: SequenceGroup,
@@ -503,3 +502,9 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
             return AllocStatus.OK
         else:
             return AllocStatus.LATER
+
+    def get_num_cached_tokens(self, seq: Sequence) -> int:
+        """Get the number of tokens in blocks that are already computed and
+        cached in the block manager for the sequence.
+        """
+        return self._computed_blocks_tracker.get_num_cached_tokens(seq)
@@ -121,3 +121,7 @@ class BlockSpaceManager(ABC):
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         """Prefix cache hit rate. -1 means not supported or disabled."""
         pass
+
+    @abstractmethod
+    def get_num_cached_tokens(self, seq: Sequence) -> int:
+        pass
@@ -89,3 +89,6 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):

     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return -1
+
+    def get_num_cached_tokens(self, seq: Sequence) -> int:
+        return 0
@@ -56,11 +56,16 @@ class SchedulingBudget:
     max_num_seqs: int
     _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
     _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
+    # Number of cached tokens in the batch.
+    _num_cached_tokens: int = 0
+    # Number of actual non-cached tokens in the batch.
     _num_batched_tokens: int = 0
     _num_curr_seqs: int = 0

     def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
-        assert num_new_tokens != 0
+        # We allow num_new_tokens to be 0 when the entire sequence has
+        # been cached.
+        assert num_new_tokens >= 0
         assert num_new_seqs != 0
         return (self.num_batched_tokens + num_new_tokens <= self.token_budget
                 and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
@@ -68,12 +73,18 @@ class SchedulingBudget:
     def remaining_token_budget(self):
         return self.token_budget - self.num_batched_tokens

-    def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
+    def add_num_batched_tokens(self,
+                               req_id: str,
+                               num_batched_tokens: int,
+                               num_cached_tokens: int = 0):
         if req_id in self._request_ids_num_batched_tokens:
             return
+        assert num_cached_tokens >= 0
+        assert num_batched_tokens >= 0

         self._request_ids_num_batched_tokens.add(req_id)
         self._num_batched_tokens += num_batched_tokens
+        self._num_cached_tokens += num_cached_tokens

     def subtract_num_batched_tokens(self, req_id: str,
                                     num_batched_tokens: int):
@ -101,6 +112,10 @@ class SchedulingBudget:
    def num_curr_seqs(self):
        return self._num_curr_seqs

+   @property
+   def num_cached_tokens(self):
+       return self._num_cached_tokens


@dataclass
class ScheduledSequenceGroup:
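
Illustrative sketch (not part of the diff): the budget now tracks cached tokens separately from batched (uncached) tokens, and only the uncached count is charged against token_budget. The minimal, self-contained class below mirrors that accounting; MiniBudget and its methods are hypothetical stand-ins, not vLLM's SchedulingBudget.

from dataclasses import dataclass, field
from typing import Set


@dataclass
class MiniBudget:
    token_budget: int
    _request_ids: Set[str] = field(default_factory=set)
    _num_batched_tokens: int = 0   # uncached tokens only
    _num_cached_tokens: int = 0    # scheduled "for free"

    def can_schedule(self, num_new_tokens: int) -> bool:
        assert num_new_tokens >= 0  # 0 is allowed for fully cached sequences
        return self._num_batched_tokens + num_new_tokens <= self.token_budget

    def add(self, req_id: str, num_batched_tokens: int,
            num_cached_tokens: int = 0) -> None:
        if req_id in self._request_ids:
            return
        self._request_ids.add(req_id)
        self._num_batched_tokens += num_batched_tokens
        self._num_cached_tokens += num_cached_tokens

    @property
    def total_scheduled_tokens(self) -> int:
        # What the scheduler outputs would report: uncached + cached.
        return self._num_batched_tokens + self._num_cached_tokens


if __name__ == "__main__":
    budget = MiniBudget(token_budget=8)
    assert budget.can_schedule(6)
    budget.add("req-0", num_batched_tokens=6, num_cached_tokens=16)
    assert not budget.can_schedule(4)        # only 2 uncached slots left
    assert budget.total_scheduled_tokens == 22
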
@ -541,9 +556,19 @@ class Scheduler:
        assert len(self._async_stopped) == 0
        while running_queue:
            seq_group = running_queue[0]
-           num_running_tokens = self._get_num_new_tokens(
-               seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
+           # We discard the cached tokens info here because we don't need it
+           # for running sequence:
+           # 1. If a sequence is running with chunked prefill, the cached
+           #    tokens info was already used for the first prefill.
+           # 2. If a sequence is running with non-chunked prefill, then
+           #    there it's a decoding sequence, and the cached tokens info is
+           #    irrelevant.
+           num_uncached_new_tokens, _ = (
+               self._get_num_new_uncached_and_cached_tokens(
+                   seq_group, SequenceStatus.RUNNING, enable_chunking,
+                   budget))
+
+           num_running_tokens = num_uncached_new_tokens
            if num_running_tokens == 0:
                # No budget => Stop
                break
@ -715,13 +740,15 @@ class Scheduler:
            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_new_seqs = seq_group.get_max_num_running_seqs()
-           num_new_tokens = self._get_num_new_tokens(seq_group,
-                                                     SequenceStatus.SWAPPED,
-                                                     enable_chunking, budget)
+           num_new_tokens_uncached, num_new_tokens_cached = (
+               self._get_num_new_uncached_and_cached_tokens(
+                   seq_group, SequenceStatus.SWAPPED, enable_chunking,
+                   budget))

-           if (num_new_tokens == 0
-                   or not budget.can_schedule(num_new_tokens=num_new_tokens,
-                                              num_new_seqs=num_new_seqs)):
+           if num_new_tokens_uncached == 0 or not budget.can_schedule(
+                   num_new_tokens=num_new_tokens_uncached,
+                   num_new_seqs=num_new_seqs,
+           ):
                break

            if lora_int_id > 0 and curr_loras is not None:
@ -732,12 +759,19 @@ class Scheduler:
            is_prefill = seq_group.is_prefill()
            if is_prefill:
                prefill_seq_groups.append(
-                   ScheduledSequenceGroup(seq_group,
-                                          token_chunk_size=num_new_tokens))
+                   ScheduledSequenceGroup(
+                       seq_group,
+                       token_chunk_size=num_new_tokens_uncached +
+                       num_new_tokens_cached,
+                   ))
            else:
                decode_seq_groups.append(
                    ScheduledSequenceGroup(seq_group, token_chunk_size=1))
-           budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
+           budget.add_num_batched_tokens(
+               seq_group.request_id,
+               num_batched_tokens=num_new_tokens_uncached,
+               num_cached_tokens=num_new_tokens_cached,
+           )
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)

        swapped_queue.extendleft(leftover_swapped)
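
Illustrative sketch (not part of the diff): a scheduled prefill's token_chunk_size covers cached plus uncached tokens (that is how far the sequence advances in one step), while the budget is charged only for the uncached part. The helper below is a hypothetical simplification working on plain integers.

from typing import Optional, Tuple


def schedule_one_prefill(token_budget: int, used_uncached: int,
                         uncached: int, cached: int) -> Optional[Tuple[int, int]]:
    """Return (token_chunk_size, new_used_uncached), or None if over budget."""
    if used_uncached + uncached > token_budget:
        return None                       # cannot schedule this chunk
    token_chunk_size = uncached + cached  # what the worker is told to process
    return token_chunk_size, used_uncached + uncached


if __name__ == "__main__":
    # 12 of 16 prompt tokens hit the prefix cache; only 4 count against the
    # token budget, yet the scheduled chunk still covers all 16 tokens.
    result = schedule_one_prefill(token_budget=8, used_uncached=0,
                                  uncached=4, cached=12)
    assert result == (16, 4)
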
@ -803,26 +837,30 @@ class Scheduler:
        if waiting_queue:
            seq_group = waiting_queue.popleft()
            num_new_seqs = seq_group.get_max_num_running_seqs()
-           num_new_tokens = self._get_num_new_tokens(seq_group,
-                                                     SequenceStatus.WAITING,
-                                                     False, budget)
+           num_new_tokens_uncached, _ = (
+               self._get_num_new_uncached_and_cached_tokens(
+                   seq_group, SequenceStatus.WAITING, False, budget))

            #Only preempt if priority inversion exists
            while running_queue and self._get_priority(
                    running_queue[-1]) > self._get_priority(seq_group):
                #Only preempt if waiting sequence cannot be allocated
                can_allocate = self.block_manager.can_allocate(seq_group)
-               if (num_new_tokens and can_allocate == AllocStatus.OK
-                       and budget.can_schedule(num_new_tokens=num_new_tokens,
-                                               num_new_seqs=num_new_seqs)):
+               if (num_new_tokens_uncached > 0
+                       and can_allocate == AllocStatus.OK
+                       and budget.can_schedule(
+                           num_new_tokens=num_new_tokens_uncached,
+                           num_new_seqs=num_new_seqs,
+                       )):
                    break

                #Adjust budget to remove the victim sequence group
                vseq_group = running_queue.pop()
-               num_running_tokens = self._get_num_new_tokens(
-                   vseq_group, SequenceStatus.RUNNING, False, budget)
-               budget.subtract_num_batched_tokens(vseq_group.request_id,
-                                                  num_running_tokens)
+               num_running_tokens_uncached, _ = (
+                   self._get_num_new_uncached_and_cached_tokens(
+                       vseq_group, SequenceStatus.RUNNING, False, budget))
+               budget.subtract_num_batched_tokens(
+                   vseq_group.request_id, num_running_tokens_uncached)
                num_running_seqs = vseq_group.get_max_num_running_seqs()
                budget.subtract_num_seqs(vseq_group.request_id,
                                         num_running_seqs)
@ -882,9 +920,12 @@ class Scheduler:
            assert len(waiting_seqs) == 1, (
                "Waiting sequence group should have only one prompt "
                "sequence.")
-           num_new_tokens = self._get_num_new_tokens(seq_group,
-                                                     SequenceStatus.WAITING,
-                                                     enable_chunking, budget)
+           num_new_tokens_uncached, num_new_tokens_cached = (
+               self._get_num_new_uncached_and_cached_tokens(
+                   seq_group, SequenceStatus.WAITING, enable_chunking,
+                   budget))
+           num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached

            if not enable_chunking:
                num_prompt_tokens = waiting_seqs[0].get_len()
                assert num_new_tokens == num_prompt_tokens
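
Illustrative sketch (not part of the diff): for a fresh, unchunked prefill the uncached and cached counts must partition the whole prompt, which is what the assertion above checks. split_prompt_tokens below is a hypothetical simplification, including the fully-cached adjustment introduced later in the diff.

from typing import Tuple


def split_prompt_tokens(prompt_len: int, num_cached_tokens: int) -> Tuple[int, int]:
    """Split a fresh prompt into (uncached, cached) new tokens."""
    cached = min(num_cached_tokens, prompt_len)
    uncached = prompt_len - cached
    if uncached == 0 and cached > 0:
        # A fully cached prompt still needs its last token recomputed.
        uncached, cached = 1, cached - 1
    return uncached, cached


if __name__ == "__main__":
    for prompt_len, cached_hint in [(16, 0), (16, 8), (16, 16)]:
        uncached, cached = split_prompt_tokens(prompt_len, cached_hint)
        assert uncached + cached == prompt_len  # matches the scheduler's assert
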
@ -935,10 +976,18 @@ class Scheduler:
                    waiting_queue.popleft()
                    continue

+           if (budget.num_batched_tokens >=
+                   self.scheduler_config.max_num_batched_tokens):
+               # We've reached the budget limit - since there might be
+               # continuous prefills in the running queue, we should break
+               # to avoid scheduling any new prefills.
+               break
+
            num_new_seqs = seq_group.get_max_num_running_seqs()
-           if (num_new_tokens == 0
-                   or not budget.can_schedule(num_new_tokens=num_new_tokens,
-                                              num_new_seqs=num_new_seqs)):
+           if num_new_tokens_uncached == 0 or not budget.can_schedule(
+                   num_new_tokens=num_new_tokens_uncached,
+                   num_new_seqs=num_new_seqs,
+           ):
                break

            # Can schedule this request.
@ -967,7 +1016,11 @@ class Scheduler:
            seq_groups.append(
                ScheduledSequenceGroup(seq_group=seq_group,
                                       token_chunk_size=num_new_tokens))
-           budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
+           budget.add_num_batched_tokens(
+               seq_group.request_id,
+               num_batched_tokens=num_new_tokens_uncached,
+               num_cached_tokens=num_new_tokens_cached,
+           )
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)

        # Queue requests that couldn't be scheduled.
@ -1075,7 +1128,8 @@ class Scheduler:
        return SchedulerOutputs(
            scheduled_seq_groups=scheduled_seq_groups,
            num_prefill_groups=num_prefill_groups,
-           num_batched_tokens=budget.num_batched_tokens,
+           num_batched_tokens=budget.num_batched_tokens +
+           budget.num_cached_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
@ -1119,7 +1173,6 @@ class Scheduler:
                running_scheduled.swapped_out) == 0:
            swapped_in = self._schedule_swapped(budget, curr_loras)

-       # Schedule new prefills.
        prefills = self._schedule_prefills(budget,
                                           curr_loras,
                                           enable_chunking=True)
@ -1157,7 +1210,8 @@ class Scheduler:
            num_prefill_groups=(len(prefills.seq_groups) +
                                len(swapped_in.prefill_seq_groups) +
                                len(running_scheduled.prefill_seq_groups)),
-           num_batched_tokens=budget.num_batched_tokens,
+           num_batched_tokens=budget.num_batched_tokens +
+           budget.num_cached_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=running_scheduled.blocks_to_copy +
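
Illustrative sketch (not part of the diff): the reported num_batched_tokens now includes cached tokens, so it stays equal to the sum of the scheduled token_chunk_size values. Plain lists stand in for the budget and the scheduled groups.

uncached_per_group = [4, 1, 1]   # charged against the token budget
cached_per_group = [12, 0, 0]    # scheduled "for free"

num_batched_tokens_reported = sum(uncached_per_group) + sum(cached_per_group)
token_chunk_sizes = [u + c for u, c in zip(uncached_per_group, cached_per_group)]

assert num_batched_tokens_reported == sum(token_chunk_sizes) == 18
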
@ -1584,30 +1638,140 @@ class Scheduler:

        return self.scheduler_config.num_lookahead_slots

-   def _get_num_new_tokens(self, seq_group: SequenceGroup,
-                           status: SequenceStatus, enable_chunking: bool,
-                           budget: SchedulingBudget) -> int:
-       """Get the next new tokens to compute for a given sequence group
-       that's in a given `status`.
+   def _get_num_new_uncached_and_cached_tokens(
+       self,
+       seq_group: SequenceGroup,
+       status: SequenceStatus,
+       enable_chunking: bool,
+       budget: SchedulingBudget,
+   ) -> Tuple[int, int]:
+       """
+       Returns the number of new uncached and cached tokens to schedule for a
+       given sequence group that's in a given `status`.

        The API could chunk the number of tokens to compute based on `budget`
        if `enable_chunking` is True. If a sequence group has multiple
        sequences (e.g., running beam search), it means it is in decoding
        phase, so chunking doesn't happen.

-       Returns 0 if the new token cannot be computed due to token budget.
+       Returns (0, 0) if the new token cannot be computed due to token budget.
+
+       The cached tokens's blocks are already computed, and the attention
+       backend will reuse the cached blocks rather than recomputing them. So
+       the scheduler could schedule these cached tokens "for free".
+
+       Args:
+           seq_group: The sequence group to get the number of new tokens to
+               schedule.
+           status: The status of the sequences to get the number of new tokens
+               to schedule.
+           enable_chunking: Whether to chunk the number of tokens to compute.
+           budget: The budget to chunk the number of tokens to compute.
+
+       Returns:
+           A tuple of two ints. The first int is the number of new uncached
+           tokens to schedule. The second int is the number of cached tokens.
+           If no more new tokens can be scheduled, returns (0, 0).
        """
-       num_new_tokens = 0
+       num_cached_new_tokens = 0
+       num_uncached_new_tokens = 0
+
        seqs = seq_group.get_seqs(status=status)
+       # Compute the number of new uncached and cached tokens for
+       # each sequence.
        for seq in seqs:
-           num_new_tokens += seq.get_num_new_tokens()
-           assert num_new_tokens > 0
+           if not seq.is_prefill():
+               # Decode sequences should always just have 1 uncached token
+               # TODO(rickyx): Actually is this still correct for multi-step?
+               num_uncached_new_tokens += 1
+               continue
+
+           num_computed_tokens_seq = seq.get_num_computed_tokens()
+           all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq
+           if not self.cache_config.enable_prefix_caching:
+               # If prefix caching is not enabled, all new tokens are uncached.
+               num_uncached_new_tokens += all_num_new_tokens_seq
+               continue
+
+           # NOTE: the cache token might be currently in a block that's in an
+           # evictor meaning that it's not yet allocated. However, we don't
+           # exclude such tokens in the cache count because it will be
+           # guaranteed to be allocated later if the sequence can be allocated.
+           num_cached_tokens_seq = self.block_manager.get_num_cached_tokens(
+               seq)
+
+           # Sanity check.
+           if num_cached_tokens_seq < num_computed_tokens_seq:
+               # This should only happen with chunked prefill, and
+               # the seq is still in prefill. The `num_cached_tokens_seq`
+               # is the value we calculated on scheduling the first prefill.
+               # For subsequent continuous prefill steps, we cached the
+               # number of cache tokens for the sequence so the cached token
+               # count could be less than the number of computed tokens.
+               # See comments on `ComputedBlocksTracker` for more details.
+               assert (
+                   seq.is_prefill() and seq.status == SequenceStatus.RUNNING
+                   and self.scheduler_config.chunked_prefill_enabled
+               ), ("Number of cached tokens should not be less than the "
+                   "number of computed tokens for a sequence that's still "
+                   f"in prefill. But there are {num_cached_tokens_seq} cached "
+                   f"tokens and {num_computed_tokens_seq} computed tokens "
+                   f"for sequence {seq.seq_id}.")
+
+           num_cached_new_tokens_seq = max(
+               0, num_cached_tokens_seq - num_computed_tokens_seq)
+           num_uncached_new_tokens_seq = (all_num_new_tokens_seq -
+                                          num_cached_new_tokens_seq)
+
+           num_uncached_new_tokens += num_uncached_new_tokens_seq
+           num_cached_new_tokens += num_cached_new_tokens_seq
+
+       if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0:
+           # For a fully cached hit sequence, we actually need to recompute the
+           # last token. So we need at least 1 uncached token to schedule.
+           # See ModelRunner._compute_for_prefix_cache_hit for more details.
+           num_uncached_new_tokens = 1
+           num_cached_new_tokens -= 1
+
+       if enable_chunking and len(seqs) == 1:
            # Chunk if a running request cannot fit in the given budget.
            # If number of seq > 1, it means it is doing beam search
            # in a decode phase. Do not chunk.
-       if enable_chunking and len(seqs) == 1:
+           num_uncached_new_tokens = self._chunk_new_tokens_to_schedule(
+               self.scheduler_config,
+               self.cache_config,
+               budget,
+               self._get_prompt_limit(seq_group),
+               num_uncached_new_tokens,
+           )
+
+       return num_uncached_new_tokens, num_cached_new_tokens
+
+   @staticmethod
+   def _chunk_new_tokens_to_schedule(
+       scheduler_config: SchedulerConfig,
+       cache_config: CacheConfig,
+       budget: SchedulingBudget,
+       prompt_limit: int,
+       num_new_tokens: int,
+   ) -> int:
+       """
+       Chunks the number of new tokens to schedule based on the budget when
+       chunked prefill is enabled.
+
+       Args:
+           scheduler_config: The scheduler config.
+           cache_config: The cache config.
+           budget: The budget to chunk the number of tokens to compute.
+           prompt_limit: The maximum number of tokens allowed in a prompt.
+           num_new_tokens: The number of new tokens to schedule.
+
+       Returns:
+           The number of new tokens to schedule after chunking.
+       """
        remaining_token_budget = budget.remaining_token_budget()
-       if self.scheduler_config.is_multi_step:
+       if scheduler_config.is_multi_step:
            # The current multi-step + chunked prefill capability does
            # not actually support chunking prompts.
            #
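
Illustrative sketch (not part of the diff): the core of the uncached/cached split computed above, reduced to a single prefill sequence and plain integers. The real code consults the block manager and sequence status; split_new_tokens and its arguments are hypothetical simplifications.

from typing import Tuple


def split_new_tokens(seq_len: int, num_computed: int, num_cached: int,
                     enable_prefix_caching: bool = True) -> Tuple[int, int]:
    """Return (num_uncached_new, num_cached_new) for one prefill sequence."""
    all_new = seq_len - num_computed
    if not enable_prefix_caching:
        return all_new, 0

    # Only cached tokens beyond what is already computed count as "new cached".
    cached_new = max(0, num_cached - num_computed)
    uncached_new = all_new - cached_new

    if uncached_new == 0 and cached_new > 0:
        # Fully cached prompt: the last token must still be recomputed.
        uncached_new, cached_new = 1, cached_new - 1
    return uncached_new, cached_new


if __name__ == "__main__":
    assert split_new_tokens(seq_len=32, num_computed=0, num_cached=16) == (16, 16)
    assert split_new_tokens(seq_len=32, num_computed=0, num_cached=32) == (1, 31)
    assert split_new_tokens(seq_len=32, num_computed=16, num_cached=16) == (16, 0)
    assert split_new_tokens(seq_len=32, num_computed=0, num_cached=0,
                            enable_prefix_caching=False) == (32, 0)
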
@ -1617,20 +1781,23 @@ class Scheduler:
            #
            # Prompts with more tokens than the current remaining budget
            # are postponed to future scheduler steps
-           if num_new_tokens > self._get_prompt_limit(seq_group):
+           if num_new_tokens > prompt_limit:
                # If the seq_group is in prompt-stage, pass the
                # num_new_tokens as-is so the caller can ignore
                # the sequence.
-               pass
-           else:
-               num_new_tokens = 0 \
-                   if num_new_tokens > remaining_token_budget \
-                   else num_new_tokens
+               return num_new_tokens
+
+           return (0 if num_new_tokens > remaining_token_budget else
+                   num_new_tokens)

-       elif self.cache_config.enable_prefix_caching:
+       if cache_config.enable_prefix_caching:
+           # Adjust the remaining token budget to be divisible by the block
+           # size when prefix caching is enabled.
+
            # When prefix caching is enabled, we always allocate
            # the number of new tokens that is dividable by the block
            # size to avoid partial block matching.
-           block_size = self.cache_config.block_size
+           block_size = cache_config.block_size
            remainder = budget.token_budget % block_size
            if remainder != 0:
                raise ValueError("When enabling chunked prefill and "
@ -1639,9 +1806,10 @@ class Scheduler:
                                 "block size, but got chunk_size "
                                 f"({budget.token_budget}) % block_size "
                                 f"({block_size}) = {remainder}")
-           if remaining_token_budget < num_new_tokens:
-               num_new_tokens = (remaining_token_budget //
-                                 block_size) * block_size
-           else:
+           # Round down to block size.
+           remaining_token_budget = (remaining_token_budget // block_size *
+                                     block_size)
        num_new_tokens = min(num_new_tokens, remaining_token_budget)

        return num_new_tokens
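
Illustrative sketch (not part of the diff): with prefix caching on, the remaining budget is rounded down to a block-size multiple before chunking, so a chunk never ends inside a partially matched block. chunk_new_tokens below is a hypothetical, standalone simplification.

def chunk_new_tokens(num_new_tokens: int, remaining_token_budget: int,
                     block_size: int, enable_prefix_caching: bool) -> int:
    if enable_prefix_caching:
        # Keep only whole blocks of budget, e.g. 30 tokens @ block_size=16 -> 16.
        remaining_token_budget = (remaining_token_budget // block_size *
                                  block_size)
    return min(num_new_tokens, remaining_token_budget)


if __name__ == "__main__":
    assert chunk_new_tokens(100, 30, 16, enable_prefix_caching=True) == 16
    assert chunk_new_tokens(100, 30, 16, enable_prefix_caching=False) == 30
    assert chunk_new_tokens(10, 30, 16, enable_prefix_caching=True) == 10
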
@ -579,6 +579,9 @@ class Sequence:
            return 1
        return self.data.get_num_uncomputed_tokens()

+   def get_num_computed_tokens(self) -> int:
+       return self.data.get_num_computed_tokens()
+
    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL