Prefix Cache Aware Scheduling [1/n] (#10128)

Signed-off-by: rickyx <rickyx@anyscale.com>
Ricky Xu 2024-11-22 21:15:55 -08:00 committed by GitHub
parent 7c25fe45a6
commit 4634a89d18
13 changed files with 962 additions and 236 deletions

View File

@ -5,9 +5,14 @@ from unittest.mock import MagicMock
import pytest
from tests.core.utils import create_dummy_sequence
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.prefix_caching_block import (PrefixCachingBlock,
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
PrefixCachingBlock,
PrefixCachingBlockAllocator)
from vllm.sequence import Logprob
from vllm.utils import Device
class TestPrefixCachingBlock:
@ -726,18 +731,71 @@ class TestPrefixCachingBlockAllocator:
token_ids=common_token_ids,
allocator=allocator,
)
block_ids = [block.block_id for block in blocks]
block_hashes = [block.content_hash for block in blocks]
# The allocated blocks should be marked as touched
# but not computed.
computed_block_ids = allocator.get_computed_block_ids(
[], block_ids, skip_last_block_id=False)
computed_block_ids = allocator.find_cached_blocks_prefix(
block_hashes)
assert len(computed_block_ids) == 0
allocator.mark_blocks_as_computed([])
computed_block_ids = allocator.get_computed_block_ids(
[], block_ids, skip_last_block_id=False)
computed_block_ids = allocator.find_cached_blocks_prefix(
block_hashes=block_hashes)
assert len(computed_block_ids) == common_blocks
@staticmethod
def test_find_cached_blocks_prefix():
"""
This test verifies the behavior of find_cached_blocks_prefix.
"""
block_size = 4
num_blocks = 8
total_test_blocks = 12
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
block_size=block_size)
token_ids = list(range(total_test_blocks * block_size))
block_tokens_seq1 = token_ids[:num_blocks * block_size]
blocks_seq1 = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=block_tokens_seq1,
allocator=allocator,
)
block_hashes_seq1 = [block.content_hash for block in blocks_seq1]
allocator.mark_blocks_as_computed([])
# All blocks should be cached.
cached_blocks_seq1 = allocator.find_cached_blocks_prefix(
block_hashes=block_hashes_seq1)
assert len(cached_blocks_seq1) == num_blocks
# Free the first sequence.
for block in blocks_seq1:
allocator.free(block)
# All blocks should still be cached if they are not needed for new allocations.
cached_blocks = allocator.find_cached_blocks_prefix(
block_hashes=block_hashes_seq1)
assert len(cached_blocks) == num_blocks
block_tokens_seq2 = token_ids[num_blocks * block_size:]
blocks_seq2 = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=block_tokens_seq2,
allocator=allocator,
)
block_hashes_seq2 = [block.content_hash for block in blocks_seq2]
allocator.mark_blocks_as_computed([])
cached_blocks = allocator.find_cached_blocks_prefix(
block_hashes=block_hashes_seq2)
assert len(cached_blocks) == len(blocks_seq2)
# Half of the blocks from seq1 should still be cached.
num_evicted_blocks = len(blocks_seq2)
cached_blocks = allocator.find_cached_blocks_prefix(
block_hashes=block_hashes_seq1)
assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks
@staticmethod
def create_immutable_chain(
block_size: int,
@ -762,3 +820,114 @@ class TestPrefixCachingBlockAllocator:
blocks.append(prev_block)
return blocks
class TestComputedBlocksTracker:
@staticmethod
def _get_mock_allocator():
return MagicMock(spec=PrefixCachingBlockAllocator)
@staticmethod
def test_get_num_cached_tokens():
"""
Test that it correctly computes the number of cached tokens for a given
sequence:
- The cached token count is derived from the number of cached blocks.
- The cached token count is updated when the allocator is updated.
- When a sequence is removed, the cached token count should be updated
accordingly.
# TODO(rickyx): This behaviour for prefill sequences is a hack until
we fix the computed blocks tracking.
- The cached token count for a prefill sequence doesn't change while
the sequence is in continuous prefill (chunked prefill).
"""
block_size = 4
mock_allocator = TestComputedBlocksTracker._get_mock_allocator()
tracker = ComputedBlocksTracker(
allocator=mock_allocator,
block_size=block_size,
enable_caching=True,
)
# Not yet allocated.
tokens = [0, 1, 2, 3, 4, 5]
seq1 = create_dummy_sequence(request_id=0,
token_ids=tokens,
block_size=block_size)
mock_allocator.find_cached_blocks_prefix.return_value = []
assert tracker.get_num_cached_tokens(seq1) == 0
mock_allocator.find_cached_blocks_prefix.return_value = [
None
] # 1 block cached.
# Result is cached for prefill sequence.
assert tracker.get_num_cached_tokens(seq1) == 0
# Mark the sequence as non-prefill.
seq1.data.update_num_computed_tokens(len(tokens)) # 6 tokens computed.
assert not seq1.is_prefill()
# Recomputes for decoding sequence.
assert tracker.get_num_cached_tokens(seq1) == 4
# Append new tokens to the sequence.
num_new_tokens = 3
for i in range(num_new_tokens):
seq1.append_token_id(i, {i: Logprob(logprob=0.0)})
assert tracker.get_num_cached_tokens(seq1) == 4
# Update the allocator.
mock_allocator.find_cached_blocks_prefix.return_value = [
None
] * 2 # 2 blocks cached.
assert tracker.get_num_cached_tokens(seq1) == 8
# Remove the sequence.
tracker.remove_seq(seq1.seq_id)
# Re-create the sequence with the same request id to simulate recompute.
seq1 = create_dummy_sequence(request_id=0,
token_ids=tokens,
block_size=block_size)
mock_allocator.find_cached_blocks_prefix.return_value = [
] # no cached block
assert tracker.get_num_cached_tokens(seq1) == 0
@staticmethod
def test_correct_block_hash():
"""
Test that the block hash is correctly computed for a sequence (it should
match the underlying block allocator's block hash), so that the number of
cached tokens is correctly retrieved.
"""
block_size = 4
allocator = CpuGpuBlockAllocator.create(
allocator_type="prefix_caching",
num_gpu_blocks=16,
num_cpu_blocks=16,
block_size=block_size,
)
gpu_allocator = allocator._allocators[Device.GPU]
tracker = ComputedBlocksTracker(
allocator=allocator,
block_size=block_size,
enable_caching=True,
)
tokens = list(range(block_size * 4)) # 4 blocks.
seq = create_dummy_sequence(request_id=0,
token_ids=tokens,
block_size=block_size)
_ = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=tokens,
allocator=gpu_allocator,
)
allocator.mark_blocks_as_computed([])
assert tracker.get_num_cached_tokens(seq) == len(tokens)
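The test above depends on the tracker reproducing the allocator's chained block hashes (via PrefixCachingBlock.hash_block_tokens). The following is a toy, self-contained sketch of that chaining idea, using a stand-in hash function rather than the allocator's real one:
from typing import List, Optional

def toy_hash_block_tokens(prev_block_hash: Optional[int],
                          block_token_ids: List[int]) -> int:
    # Each block hash folds in the previous block's hash, so a single hash
    # identifies the whole token prefix up to and including that block.
    return hash((prev_block_hash is None, prev_block_hash,
                 tuple(block_token_ids)))

def toy_hash_sequence(token_ids: List[int], block_size: int) -> List[int]:
    hashes: List[int] = []
    prev: Optional[int] = None
    # Only full blocks get a hash, matching the tracker's behaviour.
    for start in range(0, len(token_ids) // block_size * block_size,
                       block_size):
        prev = toy_hash_block_tokens(prev, token_ids[start:start + block_size])
        hashes.append(prev)
    return hashes

# Two sequences sharing their first block produce the same first hash.
assert (toy_hash_sequence(list(range(8)), 4)[0] ==
        toy_hash_sequence(list(range(4)) + [9, 9, 9, 9], 4)[0])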

View File

@ -12,9 +12,9 @@ from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroup
from .utils import (append_new_token, append_new_token_seq_group,
create_dummy_prompt, get_sequence_groups,
schedule_and_update_computed_tokens)
from .utils import (append_new_token, append_new_token_seq,
append_new_token_seq_group, create_dummy_prompt,
get_sequence_groups, schedule_and_update_computed_tokens)
def test_scheduler_add_seq_group():
@ -305,6 +305,8 @@ def initialize_scheduler(
block_size=4,
num_cpu_blocks=8,
num_gpu_blocks=8,
enable_prefix_caching=False,
enable_chunked_prefill=False,
):
block_size = block_size
scheduler_config = SchedulerConfig(
@ -312,8 +314,15 @@ def initialize_scheduler(
max_num_batched_tokens=max_token_budget,
max_num_seqs=max_num_seqs,
max_model_len=max_model_len,
enable_chunked_prefill=enable_chunked_prefill,
)
cache_config = CacheConfig(
block_size,
1.0,
1,
"auto",
enable_prefix_caching=enable_prefix_caching,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = num_cpu_blocks
cache_config.num_gpu_blocks = num_gpu_blocks
scheduler = Scheduler(scheduler_config, cache_config, lora_config)
@ -800,3 +809,165 @@ def test_scheduling_budget():
assert budget.num_curr_seqs == 0
budget.subtract_num_seqs(seq_group.request_id, 2)
assert budget.num_curr_seqs == 0
@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_prefix_caching_aware_prefills(enable_prefix_caching):
"""
Test the scenario below:
Three sequences, seqA, seqB and seqC, share the first block as a prefix.
The test verifies that:
1. SeqA is scheduled first.
2. SeqB and SeqC can be prefilled together in a single scheduling round,
even though there is not enough token budget to prefill both without
prefix caching.
"""
block_size = 4
max_num_batched_tokens = 12
max_seq_group = 3
scheduler = initialize_scheduler(
block_size=block_size,
num_cpu_blocks=16,
num_gpu_blocks=16,
max_token_budget=max_num_batched_tokens,
max_num_seqs=max_seq_group,
max_model_len=max_num_batched_tokens,
enable_prefix_caching=enable_prefix_caching,
)
seqA_tokens = list(range(8))
num_shared_tokens = 4
seqB_tokens = seqA_tokens[:num_shared_tokens] + list(range(
12, 16)) # Shared prefix first 4.
seqC_tokens = seqA_tokens[:num_shared_tokens] + list(range(
16, 20)) # Shared prefix first 4.
seqA, seqA_group = create_dummy_prompt("0",
prompt_tokens=seqA_tokens,
block_size=block_size)
seqB, seqB_group = create_dummy_prompt("1",
prompt_tokens=seqB_tokens,
block_size=block_size)
seqC, seqC_group = create_dummy_prompt("2",
prompt_tokens=seqC_tokens,
block_size=block_size)
# Schedule seqA prefill.
scheduler.add_seq_group(seqA_group)
metas, out, _ = scheduler.schedule()
assert (len(out.scheduled_seq_groups) == 1
and out.scheduled_seq_groups[0].seq_group == seqA_group)
assert out.scheduled_seq_groups[0].token_chunk_size == len(seqA_tokens)
# Schedule seqA decode.
append_new_token_seq_group(len(seqA_tokens), seqA_group, 999)
metas, out, _ = scheduler.schedule()
assert len(out.scheduled_seq_groups) == 1
assert out.scheduled_seq_groups[0].seq_group == seqA_group
assert out.scheduled_seq_groups[0].token_chunk_size == 1
# Schedule seqB and seqC prefills should work with prefix caching.
scheduler.add_seq_group(seqB_group)
scheduler.add_seq_group(seqC_group)
metas, out, _ = scheduler.schedule()
if enable_prefix_caching:
assert len(out.scheduled_seq_groups) == 2
assert set([
out.scheduled_seq_groups[0].seq_group,
out.scheduled_seq_groups[1].seq_group,
]) == set([seqB_group, seqC_group])
assert len(metas) == 2
for meta in metas:
assert meta.token_chunk_size == 8
assert (len(meta.computed_block_nums) == num_shared_tokens //
block_size)  # 1 block for the 4 shared prefix tokens.
else:
assert len(out.scheduled_seq_groups) == 1
assert len(metas) == 1
assert metas[0].token_chunk_size == 8
assert len(metas[0].computed_block_nums) == 0 # No blocks computed.
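The prefix-caching branch of this test works out numerically as follows; a back-of-the-envelope sketch in plain arithmetic, not scheduler code:
block_size = 4
max_num_batched_tokens = 12
seq_len = 8                          # seqB and seqC each have 8 prompt tokens
num_shared_tokens = block_size       # 1 block shared with seqA's prefix

uncached_per_seq = seq_len - num_shared_tokens   # 4 uncached tokens each
# Only uncached tokens count against the token budget, so both prefills fit:
assert 2 * uncached_per_seq <= max_num_batched_tokens
# Without prefix caching, the two full prompts would not fit in one round:
assert 2 * seq_len > max_num_batched_tokens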
def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
):
"""
This test verifies that we don't schedule new prefills if there's already
a continuous prefill in progress, even though the new prefills with a
shared prefix could fit in the token budget:
- SeqA is undergoing chunked prefill.
- SeqB, with the same prompt, shouldn't be scheduled for prefill even
though there's enough token budget to prefill its cached tokens.
- Neither should seqC be scheduled.
- When seqA is in the decoding phase, seqB and seqC can be scheduled.
- The entire seqB should be prefilled since it's a full prefix cache hit.
- SeqC would be prefilled with the shared prefix plus as many of its
remaining unique tokens as fit in the budget (rounded down to be
block-size aligned).
"""
block_size = 2
max_num_batched_tokens = 4
max_seq_group = 3
scheduler = initialize_scheduler(
block_size=block_size,
num_cpu_blocks=16,
num_gpu_blocks=16,
max_token_budget=max_num_batched_tokens,
max_num_seqs=max_seq_group,
max_model_len=100,
enable_prefix_caching=True,
enable_chunked_prefill=True,
)
seqA_tokens = list(range(8))
seqB_tokens = seqA_tokens
seqC_shared_prefix_len = 4
seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20))
seqA, seqA_group = create_dummy_prompt("0",
prompt_tokens=seqA_tokens,
block_size=block_size)
seqB, seqB_group = create_dummy_prompt("1",
prompt_tokens=seqB_tokens,
block_size=block_size)
# Chunked prefill seqA.
scheduler.add_seq_group(seqA_group)
metas, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
assert out.scheduled_seq_groups[0].seq_group == seqA_group
assert out.scheduled_seq_groups[0].token_chunk_size == 4
# seqB should not be scheduled with ongoing prefills.
scheduler.add_seq_group(seqB_group)
metas, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
assert out.scheduled_seq_groups[0].seq_group == seqA_group
assert out.scheduled_seq_groups[0].token_chunk_size == 4
# Both seqB and seqC can now be scheduled once seqA's prefill is over
# and seqA is in the decoding phase.
append_new_token_seq(seqA, 999)
seqC, seqC_group = create_dummy_prompt("2",
prompt_tokens=seqC_tokens,
block_size=block_size)
scheduler.add_seq_group(seqC_group)
metas, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 3
metas = {meta.request_id: meta for meta in metas}
assert metas[seqA_group.request_id].token_chunk_size == 1 # Decode
assert (metas[seqB_group.request_id].token_chunk_size == 8
) # Fully cached prefill
assert (
metas[seqC_group.request_id].token_chunk_size == 6
), ("A partial prefix of C (4 tokens) should be prefilled, with the "
"remaining tokens fitting into the 3-token budget left after seqA's "
"decode (4 - 1). That is then rounded down to 2 tokens on block size, "
"so 6 tokens are scheduled in total.")

View File

@ -1,17 +1,20 @@
import time
from typing import List, Optional
from collections import defaultdict
from typing import Any, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple
from vllm import SamplingParams
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.inputs import EncoderDecoderInputs, token_inputs
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, Sequence, SequenceGroup
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupMetadata)
def create_dummy_prompt(
request_id: str,
prompt_length: int,
prompt_length: int = -1,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
best_of: int = 1,
@ -26,6 +29,7 @@ def create_dummy_prompt(
# Create dummy prompt sequence with tokens 0...block_size-1
# and prompt "0 ... block_size".
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id),
inputs=token_inputs(prompt_tokens, prompt=prompt_str),
@ -42,6 +46,15 @@ def create_dummy_prompt(
return prompt, seq_group
def create_dummy_sequence(request_id: int, token_ids: List[int],
block_size: int) -> Sequence:
return Sequence(
seq_id=request_id,
inputs=token_inputs(token_ids),
block_size=block_size,
)
def create_dummy_prompt_encoder_decoder(
request_id: str,
decoder_prompt_length: int,
@ -194,12 +207,40 @@ def append_new_token(out, token_id: int):
def schedule_and_update_computed_tokens(scheduler):
metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
for s in out.scheduled_seq_groups:
s.seq_group.update_num_computed_tokens(s.token_chunk_size)
return metas, out
def append_new_token_seq(seq: Sequence, token_id: int):
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
seq_group.update_num_computed_tokens(token_chunk_size)
for seq in seq_group.get_seqs():
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
class SchedulerProxy:
"""
A proxy class to forward calls to the scheduler.
"""
def __init__(self, scheduler: Scheduler):
self.scheduler_ = scheduler
self.call_history: Dict[str, List[Any]] = defaultdict(list)
def __getattr__(self, name: str) -> Any:
def wrapper(*args, **kwargs):
result = getattr(self.scheduler_, name)(*args, **kwargs)
self.call_history[name].append((args, kwargs, result))
return result
return wrapper
def last_schedule_ret(
self, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Any]:
_, _, ret = self.call_history["schedule"][-1]
return ret
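A minimal usage sketch of SchedulerProxy with a mocked scheduler object; the prefix caching tests below wrap engine.scheduler[0] in the same way:
from unittest.mock import MagicMock

mock_scheduler = MagicMock()
mock_scheduler.schedule.return_value = ([], "outputs", "allow_async")

proxy = SchedulerProxy(mock_scheduler)  # type: ignore[arg-type]
proxy.schedule()                        # forwarded to the mock and recorded
metas, out, _ = proxy.last_schedule_ret()
assert out == "outputs"
assert len(proxy.call_history["schedule"]) == 1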

View File

@ -2,10 +2,15 @@
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import pytest
from tests.conftest import VllmRunner
from tests.core.utils import SchedulerProxy, create_dummy_prompt
from tests.kernels.utils import override_backend_env_variable
from vllm import SamplingParams, TokensPrompt
from vllm.core.scheduler import Scheduler
from vllm.engine.llm_engine import LLMEngine
from ..models.utils import check_outputs_equal
@ -27,6 +32,7 @@ UNSTABLE_PROMPT_SEQUENCE = [
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
@pytest.mark.parametrize("block_size", [16])
def test_mixed_requests(
hf_runner,
@ -37,6 +43,7 @@ def test_mixed_requests(
dtype: str,
max_tokens: int,
cached_position: int,
enable_chunked_prefill: bool,
block_size: int,
monkeypatch,
) -> None:
@ -55,6 +62,7 @@ def test_mixed_requests(
model,
dtype=dtype,
enable_prefix_caching=True,
enable_chunked_prefill=enable_chunked_prefill,
block_size=block_size,
) as vllm_model:
# Run the first prompt so the cache is populated
@ -72,13 +80,13 @@ def test_mixed_requests(
block_size) * block_size
else:
expected_num_cached_tokens = 0
assert req_outputs[
i].num_cached_tokens == expected_num_cached_tokens
assert (
req_outputs[i].num_cached_tokens == expected_num_cached_tokens)
vllm_outputs = [
(output.prompt_token_ids + list(output.outputs[0].token_ids),
output.prompt + output.outputs[0].text) for output in req_outputs
]
vllm_outputs = [(
output.prompt_token_ids + list(output.outputs[0].token_ids),
output.prompt + output.outputs[0].text,
) for output in req_outputs]
check_outputs_equal(
outputs_0_lst=hf_outputs,
@ -105,3 +113,89 @@ def test_unstable_prompt_sequence(
for prompt in UNSTABLE_PROMPT_SEQUENCE:
vllm_model.generate(TokensPrompt(prompt_token_ids=prompt),
SamplingParams(max_tokens=1))
@pytest.mark.parametrize("model", MODELS)
def test_fully_cached_prefill_needs_uncached_token(model):
block_size = 16
max_num_batched_tokens = 16
num_output_tokens = 5
# Make a vllm engine
runner = VllmRunner(
model_name=model,
gpu_memory_utilization=0.7,
enable_chunked_prefill=True,
enforce_eager=True,
enable_prefix_caching=True,
block_size=block_size,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_batched_tokens,
)
engine: LLMEngine = runner.model.llm_engine
scheduler: Scheduler = SchedulerProxy(engine.scheduler[0]) # type: ignore
engine.scheduler[0] = scheduler
# SeqA
seqA_tokens = list(range(2 * block_size))
seqA, seq_groupA = create_dummy_prompt(
request_id="0",
prompt_tokens=seqA_tokens,
max_tokens=num_output_tokens,
block_size=block_size,
)
scheduler.add_seq_group(seq_groupA)
assert seqA.data.get_num_computed_tokens() == 0
# Prefill seqA
while not seqA.is_finished():
engine.step()
# seqB
seqB_tokens = [t + 1 for t in seqA_tokens] # shift by 1
seqB, seq_groupB = create_dummy_prompt(
request_id="1",
prompt_tokens=seqB_tokens,
max_tokens=num_output_tokens,
block_size=block_size,
)
# seqC is the same as seqA
seqC, seq_groupC = create_dummy_prompt(
request_id="2",
prompt_tokens=seqA_tokens,
max_tokens=num_output_tokens,
block_size=block_size,
)
scheduler.add_seq_group(seq_groupB)
scheduler.add_seq_group(seq_groupC)
# Even though seqC is fully cached, it should not be prefilled since we
# require at least 1 uncached token.
engine.step()
sched_metas, sched_out, _ = scheduler.last_schedule_ret()
assert len(sched_out.scheduled_seq_groups) == 1
assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
seq_groupB.request_id)
assert (sched_out.scheduled_seq_groups[0].token_chunk_size ==
max_num_batched_tokens)
# When seqB is finished, seqC could be prefilled.
while not seqB.is_finished():
engine.step()
sched_metas, sched_out, _ = scheduler.last_schedule_ret()
assert len(sched_out.scheduled_seq_groups) == 1
assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
seq_groupB.request_id)
engine.step()
sched_metas, sched_out, _ = scheduler.last_schedule_ret()
assert len(sched_out.scheduled_seq_groups) == 1
assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
seq_groupC.request_id)
assert sched_out.scheduled_seq_groups[0].token_chunk_size == len(
seqA_tokens)

View File

@ -306,14 +306,6 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
device = Device.GPU
return self._allocators[device].mark_blocks_as_computed(block_ids)
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].get_computed_block_ids(
prev_computed_block_ids, block_ids, skip_last_block_id)
def get_common_computed_block_ids(
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
# Prefix caching only supported on GPU.
@ -342,6 +334,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
self._swap_mapping.clear()
return list(mapping.items())
def find_cached_blocks_prefix(
self,
block_hashes: List[int],
device: Device = Device.GPU,
) -> List[int]:
return self._allocators[device].find_cached_blocks_prefix(block_hashes)
class NullBlock(Block):
"""

View File

@ -159,12 +159,6 @@ class BlockAllocator(ABC):
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass
@abstractmethod
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
pass
@abstractmethod
def get_common_computed_block_ids(
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
@ -192,6 +186,13 @@ class BlockAllocator(ABC):
class NoFreeBlocksError(ValueError):
pass
@abstractmethod
def find_cached_blocks_prefix(
self,
block_hashes: List[int],
) -> List[int]:
pass
class DeviceAwareBlockAllocator(ABC):
@ -207,9 +208,12 @@ class DeviceAwareBlockAllocator(ABC):
pass
@abstractmethod
def allocate_immutable_blocks(self, prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Device) -> List[Block]:
def allocate_immutable_blocks(
self,
prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Device,
) -> List[Block]:
pass
@abstractmethod
@ -246,12 +250,6 @@ class DeviceAwareBlockAllocator(ABC):
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass
@abstractmethod
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
pass
@abstractmethod
def get_common_computed_block_ids(
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
@ -284,3 +282,11 @@ class DeviceAwareBlockAllocator(ABC):
def get_prefix_cache_hit_rate(self, device: Device) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass
@abstractmethod
def find_cached_blocks_prefix(
self,
block_hashes: List[int],
device: Device = Device.GPU,
) -> List[int]:
pass

View File

@ -262,13 +262,6 @@ class NaiveBlockAllocator(BlockAllocator):
"""
pass
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
"""No prefix caching here => return empty list
"""
return []
def get_common_computed_block_ids(
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
"""Determine blocks that can be skipped in prefill.
@ -329,6 +322,10 @@ class NaiveBlockAllocator(BlockAllocator):
def get_prefix_cache_hit_rate(self) -> float:
return -1
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
# Not applicable for naive block allocator.
return []
class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix

View File

@ -1,13 +1,18 @@
"""Token blocks."""
import sys
from bisect import bisect_left
from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple
from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set,
Tuple)
from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device,
DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
NaiveBlockAllocator)
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
from vllm.sequence import Sequence
PrefixHash = int
@ -534,26 +539,6 @@ class PrefixCachingBlockAllocator(BlockAllocator):
else:
return block_id in self.evictor
def get_computed_block_ids(self,
prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool = True) -> List[int]:
prev_prefix_size = len(prev_computed_block_ids)
cur_size = len(block_ids)
if skip_last_block_id:
cur_size -= 1
# Sanity checks
assert cur_size >= 0
assert prev_prefix_size <= cur_size
ret = prev_computed_block_ids
for i in range(prev_prefix_size, cur_size):
block_id = block_ids[i]
if self.block_is_computed(block_id):
ret.append(block_id)
return ret
def get_common_computed_block_ids(
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
"""Return the block ids that are common for a given sequence group.
@ -634,6 +619,47 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block.block_id = block_id # Assign block_id
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
"""
Given a list of block hashes, return the prefix of the block hashes that
are all cached.
Since a block's hash includes the hashes of all previous blocks, and we
only allocate/deallocate blocks for an entire sequence, if a block is
cached then all previous blocks are also cached. With this property, we
can use binary search to find the prefix of cached blocks.
Args:
block_hashes (List[int]): The list of block hashes.
Returns:
List[int]: The prefix of the `block_hashes` that are cached.
"""
def _block_is_cached(block_hash: PrefixHash) -> bool:
if block_hash not in self._cached_blocks:
return False
cached_block_id = self._cached_blocks[block_hash]
# We only consider the blocks that are marked as computed.
return self.block_is_computed(cached_block_id)
def _bisect_left(a, x, key: Callable[[PrefixHash], bool]) -> int:
# Python < 3.10 doesn't have the key argument.
if sys.version_info < (3, 10):
a = [key(e) for e in a]
return bisect_left(a, x)
else:
return bisect_left(a, x, key=key)
# Look for the first block that's not cached, and return the prefix,
# i.e. the blocks that are cached.
idx = _bisect_left(block_hashes,
True,
key=lambda x: not _block_is_cached(x))
return block_hashes[:idx]
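The bisect works because the cached-ness predicate is monotone over the hash list: once one block is uncached, every later block is too. Below is a self-contained toy version of the same prefix search, with a plain set standing in for the allocator's cached-and-computed blocks:
import sys
from bisect import bisect_left
from typing import List, Set

def cached_prefix(block_hashes: List[int], cached: Set[int]) -> List[int]:
    # key maps each hash to False while cached and True once uncached; the
    # resulting list is monotone, so bisect_left finds the first uncached
    # position in O(log N) predicate evaluations.
    def key(h: int) -> bool:
        return h not in cached

    if sys.version_info < (3, 10):
        idx = bisect_left([key(h) for h in block_hashes], True)
    else:
        idx = bisect_left(block_hashes, True, key=key)
    return block_hashes[:idx]

assert cached_prefix([10, 11, 12, 13], {10, 11}) == [10, 11]
assert cached_prefix([10, 11, 12, 13], set()) == []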
class PrefixCachingBlock(Block):
"""A block implementation that supports prefix caching.
@ -843,86 +869,126 @@ class PrefixCachingBlock(Block):
class ComputedBlocksTracker:
"""Handles caching of per-sequence computed block ids.
When a sequence appears for the first time, it traverses all of the
blocks and detects the prefix of blocks that is computed. On the
subsequent times, it only traverses the new blocks that were added
and updates the already recorded prefix of blocks with the newly
computed blocks.
"""
Tracks the computed blocks for each sequence.
To avoid redundant traversals, the algorithm also detects when there
is a "gap" in the computed prefix. For example, if we have blocks =
[1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then
we won't try to add more computed blocks to [1,2,3] in this sequence
iteration, and will add more computed blocks only after the sequence is
freed and reused again.
Internally, it maintains a map from sequence id to the list of block hashes
for the sequence. We cache the hashes of the full blocks for each sequence,
and make sure the hash is calculated in the same way as the allocator.
When a sequence is being decoded, we also update the sequence's hash
accordingly and incrementally.
Note that currently, for a given sequence, we also skip the last
block id for caching purposes, to avoid caching of a full sequence
From these block hashes, with prefix caching enabled, we can also
calculate the number of cached tokens for the sequence by looking up the
number of cached block hashes in the allocator.
"""
def __init__(self, allocator):
def __init__(
self,
allocator: DeviceAwareBlockAllocator,
block_size: int,
enable_caching: bool,
):
self._allocator = allocator
self._cached_computed_seq_blocks: Dict[int, Tuple[List[int],
bool]] = {}
self._block_size = block_size
self._enable_caching = enable_caching
def add_seq(self, seq_id: int) -> None:
"""Start tracking seq_id
"""
assert seq_id not in self._cached_computed_seq_blocks
self._cached_computed_seq_blocks[seq_id] = ([], False)
# A map from seq_id to the list of block hashes for the
# sequence. This is so that we don't have to recompute the block hashes
# for the sequence when we need to check if the sequence is cached.
# Note a block that's not full will not have its hash calculated and
# recorded.
self._seq_id_to_blocks_hashes: Dict[int, List[int]] = {}
# A map from seq_id to the number of tokens that are cached for the
# sequence.
# We need this so that a sequence in continuous prefill doesn't
# accidentally see its cached token count change. See comments in
# `get_num_cached_tokens` for more details.
self._seq_id_to_num_tokens_computed: Dict[int, int] = {}
def _update_seq_hashes(self, seq: Sequence) -> None:
"""Incrementally update the sequence's block hashes and record them."""
assert self._enable_caching
block_hashes_recorded = self._seq_id_to_blocks_hashes.get(
seq.seq_id, [])
cur_num_blocks_recorded = len(block_hashes_recorded)
token_ids = seq.get_token_ids()
assert len(token_ids) >= cur_num_blocks_recorded * self._block_size, (
f"The sequence has {len(token_ids)} tokens, but"
f" already recorded {cur_num_blocks_recorded} blocks. "
"This should not happen since we assume blocks are "
"only appended other than recomputation. When the sequence is "
"recomputed, we should have removed the info of the old blocks.")
# Update the computed block hashes for the sequence. Since only full
# blocks are considered as "computed", we take floor here.
num_computed_blocks = len(token_ids) // self._block_size
# We need to know the hash of the previous block to compute the hash of
# the current block so that blocks could be uniquely identified across
# sequences of prefixes.
prev_block_hash = (None if cur_num_blocks_recorded == 0 else
block_hashes_recorded[-1])
# Only update the computed block hashes for the new blocks
for i in range(cur_num_blocks_recorded, num_computed_blocks):
assert len(token_ids) >= (i + 1) * self._block_size
block_token_ids = token_ids[i * self._block_size:(i + 1) *
self._block_size]
# This has to be kept in sync with the allocator's hash
# calculation.
block_hash = PrefixCachingBlock.hash_block_tokens(
is_first_block=prev_block_hash is None,
prev_block_hash=prev_block_hash,
cur_block_token_ids=block_token_ids,
)
block_hashes_recorded.append(block_hash)
prev_block_hash = block_hash
self._seq_id_to_blocks_hashes[seq.seq_id] = block_hashes_recorded
def get_num_cached_tokens(self, seq: Sequence) -> int:
if not self._enable_caching:
return 0
# We always try to update the sequence hashes on the fly.
# This is to ensure that we don't miss any cached tokens for the
# sequence during decode.
# Note this routine only computes hashes for newly appended full blocks.
self._update_seq_hashes(seq)
num_computed_tokens_prev = self._seq_id_to_num_tokens_computed.get(
seq.seq_id, None)
# TODO(rickyx): This hack could be removed once we mark blocks as
# computed correctly with chunked prefills.
if num_computed_tokens_prev is not None and seq.is_prefill():
# For a sequence that is still in prefill, we don't
# recompute the number of cached tokens.
# This also handles chunked prefill correctly since currently
# we mark blocks as computed even if the sequence is still partially
# prefilled. So a continuously prefilled sequence should not
# see its cached token count change while running.
return num_computed_tokens_prev
block_hashes = self._seq_id_to_blocks_hashes[seq.seq_id]
# This is O(logN), where N is the number of blocks.
num_cached_blocks = len(
self._allocator.find_cached_blocks_prefix(block_hashes))
num_cached_tokens = num_cached_blocks * self._block_size
self._seq_id_to_num_tokens_computed[seq.seq_id] = num_cached_tokens
return num_cached_tokens
def remove_seq(self, seq_id: int) -> None:
"""Stop tracking seq_id
"""
assert seq_id in self._cached_computed_seq_blocks
del self._cached_computed_seq_blocks[seq_id]
"""Stop tracking the sequence."""
if not self._enable_caching:
return
assert seq_id in self._seq_id_to_blocks_hashes
del self._seq_id_to_blocks_hashes[seq_id]
def get_cached_computed_blocks_and_update(
self, seq_id: int, block_ids: List[int]) -> List[int]:
""" Look at the class documentation for details
"""
# Ensure seq_id is already tracked
assert seq_id in self._cached_computed_seq_blocks
# Get cached data (may be empty on the first time)
prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[
seq_id]
if has_gap:
# When gap is detected, we do not add more computed blocks at this
# sequence iteration
return prev_computed_block_ids
# We do not consider the last block id for caching purposes.
num_cur_blocks = len(block_ids) - 1
assert num_cur_blocks >= 0
if len(prev_computed_block_ids) >= num_cur_blocks:
# Cache HIT
assert len(prev_computed_block_ids) == num_cur_blocks
return prev_computed_block_ids
# If here, then we may possibly add more computed blocks. As a result,
# traverse the additional blocks after prev_computed_block_ids to
# detect more computed blocks and add them.
# Incremental init for seq_id => Look only at the new blocks
computed_block_ids = self._allocator.get_computed_block_ids( # noqa: E501
prev_computed_block_ids,
block_ids,
skip_last_block_id=
True, # We skip last block id to avoid caching of full seq
)
# Detect if there is a "gap"
has_gap = len(computed_block_ids) < num_cur_blocks
# Record
self._cached_computed_seq_blocks[seq_id] = (computed_block_ids,
has_gap)
return computed_block_ids
assert seq_id in self._seq_id_to_num_tokens_computed
del self._seq_id_to_num_tokens_computed[seq_id]
class LastAccessBlocksTracker:

View File

@ -101,7 +101,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
self._computed_blocks_tracker = ComputedBlocksTracker(
self.block_allocator)
self.block_allocator, self.block_size, self.enable_caching)
self._last_access_blocks_tracker = LastAccessBlocksTracker(
self.block_allocator)
@ -170,7 +170,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self.block_tables[seq.seq_id] = block_table
# Track seq
self._computed_blocks_tracker.add_seq(seq.seq_id)
self._last_access_blocks_tracker.add_seq(seq.seq_id)
# Assign the block table for each sequence.
@ -178,7 +177,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self.block_tables[seq.seq_id] = block_table.fork()
# Track seq
self._computed_blocks_tracker.add_seq(seq.seq_id)
self._last_access_blocks_tracker.add_seq(seq.seq_id)
# Allocate cross-attention block table for encoder sequence
@ -314,11 +312,13 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
"""
computed_seq_block_ids = []
for seq in seqs:
computed_seq_block_ids.append(
self._computed_blocks_tracker.
get_cached_computed_blocks_and_update(
seq.seq_id,
self.block_tables[seq.seq_id].physical_block_ids))
all_blocks = self.block_tables[seq.seq_id].physical_block_ids
num_cached_tokens = (
self._computed_blocks_tracker.get_num_cached_tokens(seq))
assert num_cached_tokens % self.block_size == 0
num_cached_blocks = num_cached_tokens // self.block_size
computed_block_ids = all_blocks[:num_cached_blocks]
computed_seq_block_ids.append(computed_block_ids)
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return self.block_allocator.get_common_computed_block_ids(
@ -332,7 +332,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self.block_tables[child_seq.seq_id] = src_block_table.fork()
# Track child seq
self._computed_blocks_tracker.add_seq(child_seq.seq_id)
self._last_access_blocks_tracker.add_seq(child_seq.seq_id)
def can_swap_in(self, seq_group: SequenceGroup,
@ -503,3 +502,9 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
return AllocStatus.OK
else:
return AllocStatus.LATER
def get_num_cached_tokens(self, seq: Sequence) -> int:
"""Get the number of tokens in blocks that are already computed and
cached in the block manager for the sequence.
"""
return self._computed_blocks_tracker.get_num_cached_tokens(seq)
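The computed-block lookup earlier in this file turns the cached token count back into block ids; in miniature, with made-up block ids:
block_size = 16
all_blocks = [7, 3, 9, 5]            # physical block ids of a sequence
num_cached_tokens = 32               # guaranteed block-aligned by the assert
num_cached_blocks = num_cached_tokens // block_size
computed_block_ids = all_blocks[:num_cached_blocks]
assert computed_block_ids == [7, 3]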

View File

@ -121,3 +121,7 @@ class BlockSpaceManager(ABC):
def get_prefix_cache_hit_rate(self, device: Device) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass
@abstractmethod
def get_num_cached_tokens(self, seq: Sequence) -> int:
pass

View File

@ -89,3 +89,6 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return -1
def get_num_cached_tokens(self, seq: Sequence) -> int:
return 0

View File

@ -56,11 +56,16 @@ class SchedulingBudget:
max_num_seqs: int
_request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
_request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
# Number of cached tokens in the batch.
_num_cached_tokens: int = 0
# Number of actual non-cached tokens in the batch.
_num_batched_tokens: int = 0
_num_curr_seqs: int = 0
def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
assert num_new_tokens != 0
# We allow num_new_tokens to be 0 when the entire sequence has
# been cached.
assert num_new_tokens >= 0
assert num_new_seqs != 0
return (self.num_batched_tokens + num_new_tokens <= self.token_budget
and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
@ -68,12 +73,18 @@ class SchedulingBudget:
def remaining_token_budget(self):
return self.token_budget - self.num_batched_tokens
def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
def add_num_batched_tokens(self,
req_id: str,
num_batched_tokens: int,
num_cached_tokens: int = 0):
if req_id in self._request_ids_num_batched_tokens:
return
assert num_cached_tokens >= 0
assert num_batched_tokens >= 0
self._request_ids_num_batched_tokens.add(req_id)
self._num_batched_tokens += num_batched_tokens
self._num_cached_tokens += num_cached_tokens
def subtract_num_batched_tokens(self, req_id: str,
num_batched_tokens: int):
@ -101,6 +112,10 @@ class SchedulingBudget:
def num_curr_seqs(self):
return self._num_curr_seqs
@property
def num_cached_tokens(self):
return self._num_cached_tokens
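A small sketch of how the extended budget accounting behaves; the constructor fields (token_budget, max_num_seqs) are taken from the surrounding dataclass and the numbers are arbitrary:
# Only uncached tokens count against the token budget; cached tokens are
# tracked separately and surfaced via num_cached_tokens.
budget = SchedulingBudget(token_budget=12, max_num_seqs=4)
assert budget.can_schedule(num_new_tokens=4, num_new_seqs=1)
budget.add_num_batched_tokens("req-0",
                              num_batched_tokens=4,
                              num_cached_tokens=8)
assert budget.num_batched_tokens == 4        # uncached only
assert budget.num_cached_tokens == 8
assert budget.remaining_token_budget() == 8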
@dataclass
class ScheduledSequenceGroup:
@ -541,9 +556,19 @@ class Scheduler:
assert len(self._async_stopped) == 0
while running_queue:
seq_group = running_queue[0]
num_running_tokens = self._get_num_new_tokens(
seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
# We discard the cached tokens info here because we don't need it
# for a running sequence:
# 1. If a sequence is running with chunked prefill, the cached
# tokens info was already used for the first prefill.
# 2. If a sequence is running with non-chunked prefill, then
# it is a decoding sequence, and the cached tokens info is
# irrelevant.
num_uncached_new_tokens, _ = (
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.RUNNING, enable_chunking,
budget))
num_running_tokens = num_uncached_new_tokens
if num_running_tokens == 0:
# No budget => Stop
break
@ -715,13 +740,15 @@ class Scheduler:
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.get_max_num_running_seqs()
num_new_tokens = self._get_num_new_tokens(seq_group,
SequenceStatus.SWAPPED,
enable_chunking, budget)
num_new_tokens_uncached, num_new_tokens_cached = (
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.SWAPPED, enable_chunking,
budget))
if (num_new_tokens == 0
or not budget.can_schedule(num_new_tokens=num_new_tokens,
num_new_seqs=num_new_seqs)):
if num_new_tokens_uncached == 0 or not budget.can_schedule(
num_new_tokens=num_new_tokens_uncached,
num_new_seqs=num_new_seqs,
):
break
if lora_int_id > 0 and curr_loras is not None:
@ -732,12 +759,19 @@ class Scheduler:
is_prefill = seq_group.is_prefill()
if is_prefill:
prefill_seq_groups.append(
ScheduledSequenceGroup(seq_group,
token_chunk_size=num_new_tokens))
ScheduledSequenceGroup(
seq_group,
token_chunk_size=num_new_tokens_uncached +
num_new_tokens_cached,
))
else:
decode_seq_groups.append(
ScheduledSequenceGroup(seq_group, token_chunk_size=1))
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
budget.add_num_batched_tokens(
seq_group.request_id,
num_batched_tokens=num_new_tokens_uncached,
num_cached_tokens=num_new_tokens_cached,
)
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
swapped_queue.extendleft(leftover_swapped)
@ -803,26 +837,30 @@ class Scheduler:
if waiting_queue:
seq_group = waiting_queue.popleft()
num_new_seqs = seq_group.get_max_num_running_seqs()
num_new_tokens = self._get_num_new_tokens(seq_group,
SequenceStatus.WAITING,
False, budget)
num_new_tokens_uncached, _ = (
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.WAITING, False, budget))
# Only preempt if priority inversion exists
while running_queue and self._get_priority(
running_queue[-1]) > self._get_priority(seq_group):
# Only preempt if waiting sequence cannot be allocated
can_allocate = self.block_manager.can_allocate(seq_group)
if (num_new_tokens and can_allocate == AllocStatus.OK
and budget.can_schedule(num_new_tokens=num_new_tokens,
num_new_seqs=num_new_seqs)):
if (num_new_tokens_uncached > 0
and can_allocate == AllocStatus.OK
and budget.can_schedule(
num_new_tokens=num_new_tokens_uncached,
num_new_seqs=num_new_seqs,
)):
break
# Adjust budget to remove the victim sequence group
vseq_group = running_queue.pop()
num_running_tokens = self._get_num_new_tokens(
vseq_group, SequenceStatus.RUNNING, False, budget)
budget.subtract_num_batched_tokens(vseq_group.request_id,
num_running_tokens)
num_running_tokens_uncached, _ = (
self._get_num_new_uncached_and_cached_tokens(
vseq_group, SequenceStatus.RUNNING, False, budget))
budget.subtract_num_batched_tokens(
vseq_group.request_id, num_running_tokens_uncached)
num_running_seqs = vseq_group.get_max_num_running_seqs()
budget.subtract_num_seqs(vseq_group.request_id,
num_running_seqs)
@ -882,9 +920,12 @@ class Scheduler:
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
num_new_tokens = self._get_num_new_tokens(seq_group,
SequenceStatus.WAITING,
enable_chunking, budget)
num_new_tokens_uncached, num_new_tokens_cached = (
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.WAITING, enable_chunking,
budget))
num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached
if not enable_chunking:
num_prompt_tokens = waiting_seqs[0].get_len()
assert num_new_tokens == num_prompt_tokens
@ -935,10 +976,18 @@ class Scheduler:
waiting_queue.popleft()
continue
if (budget.num_batched_tokens >=
self.scheduler_config.max_num_batched_tokens):
# We've reached the budget limit - since there might be
# continuous prefills in the running queue, we should break
# to avoid scheduling any new prefills.
break
num_new_seqs = seq_group.get_max_num_running_seqs()
if (num_new_tokens == 0
or not budget.can_schedule(num_new_tokens=num_new_tokens,
num_new_seqs=num_new_seqs)):
if num_new_tokens_uncached == 0 or not budget.can_schedule(
num_new_tokens=num_new_tokens_uncached,
num_new_seqs=num_new_seqs,
):
break
# Can schedule this request.
@ -967,7 +1016,11 @@ class Scheduler:
seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group,
token_chunk_size=num_new_tokens))
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
budget.add_num_batched_tokens(
seq_group.request_id,
num_batched_tokens=num_new_tokens_uncached,
num_cached_tokens=num_new_tokens_cached,
)
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
# Queue requests that couldn't be scheduled.
@ -1075,7 +1128,8 @@ class Scheduler:
return SchedulerOutputs(
scheduled_seq_groups=scheduled_seq_groups,
num_prefill_groups=num_prefill_groups,
num_batched_tokens=budget.num_batched_tokens,
num_batched_tokens=budget.num_batched_tokens +
budget.num_cached_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
@ -1119,7 +1173,6 @@ class Scheduler:
running_scheduled.swapped_out) == 0:
swapped_in = self._schedule_swapped(budget, curr_loras)
# Schedule new prefills.
prefills = self._schedule_prefills(budget,
curr_loras,
enable_chunking=True)
@ -1157,7 +1210,8 @@ class Scheduler:
num_prefill_groups=(len(prefills.seq_groups) +
len(swapped_in.prefill_seq_groups) +
len(running_scheduled.prefill_seq_groups)),
num_batched_tokens=budget.num_batched_tokens,
num_batched_tokens=budget.num_batched_tokens +
budget.num_cached_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=running_scheduled.blocks_to_copy +
@ -1584,64 +1638,178 @@ class Scheduler:
return self.scheduler_config.num_lookahead_slots
def _get_num_new_tokens(self, seq_group: SequenceGroup,
status: SequenceStatus, enable_chunking: bool,
budget: SchedulingBudget) -> int:
"""Get the next new tokens to compute for a given sequence group
that's in a given `status`.
def _get_num_new_uncached_and_cached_tokens(
self,
seq_group: SequenceGroup,
status: SequenceStatus,
enable_chunking: bool,
budget: SchedulingBudget,
) -> Tuple[int, int]:
"""
Returns the number of new uncached and cached tokens to schedule for a
given sequence group that's in a given `status`.
The API could chunk the number of tokens to compute based on `budget`
if `enable_chunking` is True. If a sequence group has multiple
sequences (e.g., running beam search), it means it is in decoding
phase, so chunking doesn't happen.
Returns 0 if the new token cannot be computed due to token budget.
Returns (0, 0) if the new tokens cannot be scheduled due to the token budget.
The cached tokens' blocks are already computed, and the attention
backend will reuse the cached blocks rather than recomputing them, so
the scheduler can schedule these cached tokens "for free".
Args:
seq_group: The sequence group to get the number of new tokens to
schedule.
status: The status of the sequences to get the number of new tokens
to schedule.
enable_chunking: Whether to chunk the number of tokens to compute.
budget: The budget to chunk the number of tokens to compute.
Returns:
A tuple of two ints. The first int is the number of new uncached
tokens to schedule. The second int is the number of cached tokens.
If no more new tokens can be scheduled, returns (0, 0).
"""
num_new_tokens = 0
num_cached_new_tokens = 0
num_uncached_new_tokens = 0
seqs = seq_group.get_seqs(status=status)
# Compute the number of new uncached and cached tokens for
# each sequence.
for seq in seqs:
num_new_tokens += seq.get_num_new_tokens()
assert num_new_tokens > 0
# Chunk if a running request cannot fit in the given budget.
# If number of seq > 1, it means it is doing beam search
# in a decode phase. Do not chunk.
if not seq.is_prefill():
# Decode sequences should always just have 1 uncached token
# TODO(rickyx): Actually is this still correct for multi-step?
num_uncached_new_tokens += 1
continue
num_computed_tokens_seq = seq.get_num_computed_tokens()
all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq
if not self.cache_config.enable_prefix_caching:
# If prefix caching is not enabled, all new tokens are uncached.
num_uncached_new_tokens += all_num_new_tokens_seq
continue
# NOTE: the cached tokens might currently be in blocks that are in the
# evictor, meaning they are not yet allocated. However, we don't
# exclude such tokens from the cache count because they are
# guaranteed to be allocated later if the sequence can be allocated.
num_cached_tokens_seq = self.block_manager.get_num_cached_tokens(
seq)
# Sanity check.
if num_cached_tokens_seq < num_computed_tokens_seq:
# This should only happen with chunked prefill while the
# seq is still in prefill. The `num_cached_tokens_seq`
# is the value we calculated when scheduling the first prefill
# chunk. For subsequent continuous prefill steps, we keep that
# cached value for the sequence, so the cached token count
# can be less than the number of computed tokens.
# See comments on `ComputedBlocksTracker` for more details.
assert (
seq.is_prefill() and seq.status == SequenceStatus.RUNNING
and self.scheduler_config.chunked_prefill_enabled
), ("Number of cached tokens should not be less than the "
"number of computed tokens for a sequence that's still "
f"in prefill. But there are {num_cached_tokens_seq} cached "
f"tokens and {num_computed_tokens_seq} computed tokens "
f"for sequence {seq.seq_id}.")
num_cached_new_tokens_seq = max(
0, num_cached_tokens_seq - num_computed_tokens_seq)
num_uncached_new_tokens_seq = (all_num_new_tokens_seq -
num_cached_new_tokens_seq)
num_uncached_new_tokens += num_uncached_new_tokens_seq
num_cached_new_tokens += num_cached_new_tokens_seq
if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0:
# For a sequence with a full prefix cache hit, we still need to
# recompute the last token. So we need at least 1 uncached token to
# schedule.
# See ModelRunner._compute_for_prefix_cache_hit for more details.
num_uncached_new_tokens = 1
num_cached_new_tokens -= 1
if enable_chunking and len(seqs) == 1:
remaining_token_budget = budget.remaining_token_budget()
if self.scheduler_config.is_multi_step:
# The current multi-step + chunked prefill capability does
# not actually support chunking prompts.
#
# Therefore, `num_new_tokens` is computed in the same fashion
# for both multi-step+chunked-prefill &
# multi-step+chunked-prefill+APC
#
# Prompts with more tokens than the current remaining budget
# are postponed to future scheduler steps
if num_new_tokens > self._get_prompt_limit(seq_group):
# If the seq_group is in prompt-stage, pass the
# num_new_tokens as-is so the caller can ignore
# the sequence.
pass
else:
num_new_tokens = 0 \
if num_new_tokens > remaining_token_budget \
else num_new_tokens
elif self.cache_config.enable_prefix_caching:
# When prefix caching is enabled, we always allocate
# the number of new tokens that is dividable by the block
# size to avoid partial block matching.
block_size = self.cache_config.block_size
remainder = budget.token_budget % block_size
if remainder != 0:
raise ValueError("When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by "
"block size, but got chunk_size "
f"({budget.token_budget}) % block_size "
f"({block_size}) = {remainder}")
if remaining_token_budget < num_new_tokens:
num_new_tokens = (remaining_token_budget //
block_size) * block_size
else:
num_new_tokens = min(num_new_tokens, remaining_token_budget)
# Chunk if a running request cannot fit in the given budget.
# If number of seq > 1, it means it is doing beam search
# in a decode phase. Do not chunk.
num_uncached_new_tokens = self._chunk_new_tokens_to_schedule(
self.scheduler_config,
self.cache_config,
budget,
self._get_prompt_limit(seq_group),
num_uncached_new_tokens,
)
return num_uncached_new_tokens, num_cached_new_tokens
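As a concrete illustration of the uncached/cached split computed above, here is a pure-arithmetic sketch mirroring the prefill branch (before chunking), not the scheduler itself:
from typing import Tuple

def split_new_tokens(prompt_len: int, num_computed: int,
                     num_cached: int) -> Tuple[int, int]:
    new_tokens = prompt_len - num_computed
    cached_new = max(0, num_cached - num_computed)
    uncached_new = new_tokens - cached_new
    if uncached_new == 0 and cached_new > 0:
        # Full prefix cache hit: still recompute the last token.
        uncached_new, cached_new = 1, cached_new - 1
    return uncached_new, cached_new

assert split_new_tokens(32, 0, 16) == (16, 16)   # half the prompt cached
assert split_new_tokens(32, 0, 32) == (1, 31)    # fully cached prompt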
@staticmethod
def _chunk_new_tokens_to_schedule(
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
budget: SchedulingBudget,
prompt_limit: int,
num_new_tokens: int,
) -> int:
"""
Chunks the number of new tokens to schedule based on the budget when
chunked prefill is enabled.
Args:
scheduler_config: The scheduler config.
cache_config: The cache config.
budget: The budget to chunk the number of tokens to compute.
prompt_limit: The maximum number of tokens allowed in a prompt.
num_new_tokens: The number of new tokens to schedule.
Returns:
The number of new tokens to schedule after chunking.
"""
remaining_token_budget = budget.remaining_token_budget()
if scheduler_config.is_multi_step:
# The current multi-step + chunked prefill capability does
# not actually support chunking prompts.
#
# Therefore, `num_new_tokens` is computed in the same fashion
# for both multi-step+chunked-prefill &
# multi-step+chunked-prefill+APC
#
# Prompts with more tokens than the current remaining budget
# are postponed to future scheduler steps
if num_new_tokens > prompt_limit:
# If the seq_group is in prompt-stage, pass the
# num_new_tokens as-is so the caller can ignore
# the sequence.
return num_new_tokens
return (0 if num_new_tokens > remaining_token_budget else
num_new_tokens)
if cache_config.enable_prefix_caching:
# When prefix caching is enabled, adjust the remaining token budget
# to be divisible by the block size, so that we always allocate a
# number of new tokens that is divisible by the block size and avoid
# partial block matching.
block_size = cache_config.block_size
remainder = budget.token_budget % block_size
if remainder != 0:
raise ValueError("When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by "
"block size, but got chunk_size "
f"({budget.token_budget}) % block_size "
f"({block_size}) = {remainder}")
# Round down to block size.
remaining_token_budget = (remaining_token_budget // block_size *
block_size)
num_new_tokens = min(num_new_tokens, remaining_token_budget)
return num_new_tokens
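The block-aligned rounding at the end means a chunk can come out smaller than the raw remaining budget; for example, in plain arithmetic:
# With a 16-token block size and 19 tokens of budget left, a 40-token
# prompt is chunked to 16 tokens (19 rounded down to a block boundary).
block_size = 16
remaining_token_budget = 19
num_new_tokens = 40
chunk = min(num_new_tokens,
            remaining_token_budget // block_size * block_size)
assert chunk == 16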

View File

@ -579,6 +579,9 @@ class Sequence:
return 1
return self.data.get_num_uncomputed_tokens()
def get_num_computed_tokens(self) -> int:
return self.data.get_num_computed_tokens()
def is_prefill(self) -> bool:
return self.data.stage == SequenceStage.PREFILL