Prefix Cache Aware Scheduling [1/n] (#10128)

Signed-off-by: rickyx <rickyx@anyscale.com>
Ricky Xu 2024-11-22 21:15:55 -08:00 committed by GitHub
parent 7c25fe45a6
commit 4634a89d18
13 changed files with 962 additions and 236 deletions

View File

@ -5,9 +5,14 @@ from unittest.mock import MagicMock
import pytest
from tests.core.utils import create_dummy_sequence
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.prefix_caching_block import (PrefixCachingBlock,
PrefixCachingBlockAllocator)
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
PrefixCachingBlock,
PrefixCachingBlockAllocator)
from vllm.sequence import Logprob
from vllm.utils import Device
class TestPrefixCachingBlock:
@ -726,18 +731,71 @@ class TestPrefixCachingBlockAllocator:
token_ids=common_token_ids,
allocator=allocator,
)
block_ids = [block.block_id for block in blocks]
block_hashes = [block.content_hash for block in blocks]
# The allocated blocks should be marked as touched
# but not computed.
computed_block_ids = allocator.get_computed_block_ids(
[], block_ids, skip_last_block_id=False)
computed_block_ids = allocator.find_cached_blocks_prefix(
block_hashes)
assert len(computed_block_ids) == 0
allocator.mark_blocks_as_computed([])
computed_block_ids = allocator.get_computed_block_ids(
[], block_ids, skip_last_block_id=False)
computed_block_ids = allocator.find_cached_blocks_prefix(
block_hashes=block_hashes)
assert len(computed_block_ids) == common_blocks
@staticmethod
def test_find_cached_blocks_prefix():
"""
This test verifies the behavior of find_cached_blocks_prefix.
"""
block_size = 4
num_blocks = 8
total_test_blocks = 12
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
block_size=block_size)
token_ids = list(range(total_test_blocks * block_size))
block_tokens_seq1 = token_ids[:num_blocks * block_size]
blocks_seq1 = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=block_tokens_seq1,
allocator=allocator,
)
block_hashes_seq1 = [block.content_hash for block in blocks_seq1]
allocator.mark_blocks_as_computed([])
# All blocks should be cached.
cached_blocks_seq1 = allocator.find_cached_blocks_prefix(
block_hashes=block_hashes_seq1)
assert len(cached_blocks_seq1) == num_blocks
# Free the first sequence.
for block in blocks_seq1:
allocator.free(block)
# All blocks should still be cached as long as they are not reclaimed for new allocations.
cached_blocks = allocator.find_cached_blocks_prefix(
block_hashes=block_hashes_seq1)
assert len(cached_blocks) == num_blocks
block_tokens_seq2 = token_ids[num_blocks * block_size:]
blocks_seq2 = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=block_tokens_seq2,
allocator=allocator,
)
block_hashes_seq2 = [block.content_hash for block in blocks_seq2]
allocator.mark_blocks_as_computed([])
cached_blocks = allocator.find_cached_blocks_prefix(
block_hashes=block_hashes_seq2)
assert len(cached_blocks) == len(blocks_seq2)
# Half of the blocks from seq1 should still be cached.
num_evicted_blocks = len(blocks_seq2)
cached_blocks = allocator.find_cached_blocks_prefix(
block_hashes=block_hashes_seq1)
assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks
@staticmethod
def create_immutable_chain(
block_size: int,
@ -762,3 +820,114 @@ class TestPrefixCachingBlockAllocator:
blocks.append(prev_block)
return blocks
class TestComputedBlocksTracker:
@staticmethod
def _get_mock_allocator():
return MagicMock(spec=PrefixCachingBlockAllocator)
@staticmethod
def test_get_num_cached_tokens():
"""
Test that it correctly computes the number of cached tokens for a given
sequence:
- The cached token count is derived from the number of cached blocks.
- The cached token count is updated when the allocator is updated.
- When a sequence is removed, its cached token count is updated
accordingly.
# TODO(rickyx): This behaviour for prefill sequences is a hack until
we fix the computed blocks tracking.
- The cached token count for a prefill sequence doesn't change while
the sequence is in continuous prefill (chunked prefill).
"""
block_size = 4
mock_allocator = TestComputedBlocksTracker._get_mock_allocator()
tracker = ComputedBlocksTracker(
allocator=mock_allocator,
block_size=block_size,
enable_caching=True,
)
# Not yet allocated.
tokens = [0, 1, 2, 3, 4, 5]
seq1 = create_dummy_sequence(request_id=0,
token_ids=tokens,
block_size=block_size)
mock_allocator.find_cached_blocks_prefix.return_value = []
assert tracker.get_num_cached_tokens(seq1) == 0
mock_allocator.find_cached_blocks_prefix.return_value = [
None
] # 1 block cached.
# Result is cached for prefill sequence.
assert tracker.get_num_cached_tokens(seq1) == 0
# Mark the sequence as non-prefill.
seq1.data.update_num_computed_tokens(len(tokens)) # 6 tokens computed.
assert not seq1.is_prefill()
# Recomputes for decoding sequence.
assert tracker.get_num_cached_tokens(seq1) == 4
# Append new tokens to the sequence.
num_new_tokens = 3
for i in range(num_new_tokens):
seq1.append_token_id(i, {i: Logprob(logprob=0.0)})
assert tracker.get_num_cached_tokens(seq1) == 4
# Update the allocator.
mock_allocator.find_cached_blocks_prefix.return_value = [
None
] * 2 # 2 blocks cached.
assert tracker.get_num_cached_tokens(seq1) == 8
# Remove the sequence.
tracker.remove_seq(seq1.seq_id)
# Re-create the sequence with the same request id to simulate recompute.
seq1 = create_dummy_sequence(request_id=0,
token_ids=tokens,
block_size=block_size)
mock_allocator.find_cached_blocks_prefix.return_value = [
] # no cached block
assert tracker.get_num_cached_tokens(seq1) == 0
@staticmethod
def test_correct_block_hash():
"""
Test that the block hashes are computed for a sequence in the same way as
in the underlying block allocator, so that the number of cached tokens is
retrieved correctly.
"""
block_size = 4
allocator = CpuGpuBlockAllocator.create(
allocator_type="prefix_caching",
num_gpu_blocks=16,
num_cpu_blocks=16,
block_size=block_size,
)
gpu_allocator = allocator._allocators[Device.GPU]
tracker = ComputedBlocksTracker(
allocator=allocator,
block_size=block_size,
enable_caching=True,
)
tokens = list(range(block_size * 4)) # 4 blocks.
seq = create_dummy_sequence(request_id=0,
token_ids=tokens,
block_size=block_size)
_ = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=tokens,
allocator=gpu_allocator,
)
allocator.mark_blocks_as_computed([])
assert tracker.get_num_cached_tokens(seq) == len(tokens)

View File

@ -12,9 +12,9 @@ from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroup
from .utils import (append_new_token, append_new_token_seq_group,
create_dummy_prompt, get_sequence_groups,
schedule_and_update_computed_tokens)
from .utils import (append_new_token, append_new_token_seq,
append_new_token_seq_group, create_dummy_prompt,
get_sequence_groups, schedule_and_update_computed_tokens)
def test_scheduler_add_seq_group():
@ -305,6 +305,8 @@ def initialize_scheduler(
block_size=4,
num_cpu_blocks=8,
num_gpu_blocks=8,
enable_prefix_caching=False,
enable_chunked_prefill=False,
):
block_size = block_size
scheduler_config = SchedulerConfig(
@ -312,8 +314,15 @@ def initialize_scheduler(
max_num_batched_tokens=max_token_budget,
max_num_seqs=max_num_seqs,
max_model_len=max_model_len,
enable_chunked_prefill=enable_chunked_prefill,
)
cache_config = CacheConfig(
block_size,
1.0,
1,
"auto",
enable_prefix_caching=enable_prefix_caching,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = num_cpu_blocks
cache_config.num_gpu_blocks = num_gpu_blocks
scheduler = Scheduler(scheduler_config, cache_config, lora_config)
@ -800,3 +809,165 @@ def test_scheduling_budget():
assert budget.num_curr_seqs == 0
budget.subtract_num_seqs(seq_group.request_id, 2)
assert budget.num_curr_seqs == 0
@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_prefix_caching_aware_prefills(enable_prefix_caching):
"""
Test the following scenario:
Three sequences, seqA, seqB, and seqC, share the first block as a prefix.
The test verifies that:
1. SeqA is scheduled and prefilled first.
2. SeqB and SeqC can then be prefilled together in a single scheduling
round, even though the token budget is not large enough to prefill both
without prefix caching.
"""
block_size = 4
max_num_batched_tokens = 12
max_seq_group = 3
scheduler = initialize_scheduler(
block_size=block_size,
num_cpu_blocks=16,
num_gpu_blocks=16,
max_token_budget=max_num_batched_tokens,
max_num_seqs=max_seq_group,
max_model_len=max_num_batched_tokens,
enable_prefix_caching=enable_prefix_caching,
)
seqA_tokens = list(range(8))
num_shared_tokens = 4
seqB_tokens = seqA_tokens[:num_shared_tokens] + list(range(
12, 16)) # Shared prefix first 4.
seqC_tokens = seqA_tokens[:num_shared_tokens] + list(range(
16, 20)) # Shared prefix first 4.
seqA, seqA_group = create_dummy_prompt("0",
prompt_tokens=seqA_tokens,
block_size=block_size)
seqB, seqB_group = create_dummy_prompt("1",
prompt_tokens=seqB_tokens,
block_size=block_size)
seqC, seqC_group = create_dummy_prompt("2",
prompt_tokens=seqC_tokens,
block_size=block_size)
# Schedule seqA prefill.
scheduler.add_seq_group(seqA_group)
metas, out, _ = scheduler.schedule()
assert (len(out.scheduled_seq_groups) == 1
and out.scheduled_seq_groups[0].seq_group == seqA_group)
assert out.scheduled_seq_groups[0].token_chunk_size == len(seqA_tokens)
# Schedule seqA decode.
append_new_token_seq_group(len(seqA_tokens), seqA_group, 999)
metas, out, _ = scheduler.schedule()
assert len(out.scheduled_seq_groups) == 1
assert out.scheduled_seq_groups[0].seq_group == seqA_group
assert out.scheduled_seq_groups[0].token_chunk_size == 1
# Scheduling the seqB and seqC prefills should now work thanks to prefix caching.
scheduler.add_seq_group(seqB_group)
scheduler.add_seq_group(seqC_group)
metas, out, _ = scheduler.schedule()
if enable_prefix_caching:
assert len(out.scheduled_seq_groups) == 2
assert set([
out.scheduled_seq_groups[0].seq_group,
out.scheduled_seq_groups[1].seq_group,
]) == set([seqB_group, seqC_group])
assert len(metas) == 2
for meta in metas:
assert meta.token_chunk_size == 8
assert (len(meta.computed_block_nums) == num_shared_tokens //
block_size)  # 1 block for the 4 shared prefix tokens.
else:
assert len(out.scheduled_seq_groups) == 1
assert len(metas) == 1
assert metas[0].token_chunk_size == 8
assert len(metas[0].computed_block_nums) == 0 # No blocks computed.
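A quick sketch of the budget arithmetic this scenario relies on, assuming (as the scheduler change later in this PR suggests) that only uncached tokens are charged against the token budget; the numbers mirror the test above:

block_size = 4
token_budget = 12
seq_tokens = 8          # length of the seqB and seqC prompts
shared_prefix = 4       # one full block shared with seqA
# Without prefix caching, both prefills are charged in full: 16 > 12,
# so only one of them fits in a single scheduling round.
assert 2 * seq_tokens > token_budget
# With prefix caching, only the uncached suffix counts: 8 <= 12,
# so seqB and seqC can be prefilled together.
assert 2 * (seq_tokens - shared_prefix) <= token_budget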
def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
):
"""
This test verifies that we don't schedule new prefills while there's already
a continuous (chunked) prefill in progress, even though the new prefills
share a prefix and would fit in the token budget:
- SeqA is undergoing chunked prefill.
- SeqB, with the same prompt, shouldn't be scheduled for prefill even though
there's enough token budget to cover its cached tokens.
- Neither should seqC be scheduled.
- Once seqA is in the decoding phase, seqB and seqC can be scheduled.
- SeqB should be prefilled in its entirety since it's a full prefix cache hit.
- SeqC should be partially prefilled: the shared prefix plus as many of the
remaining unique tokens as fit in the budget (rounded down to be
block-size aligned).
"""
block_size = 2
max_num_batched_tokens = 4
max_seq_group = 3
scheduler = initialize_scheduler(
block_size=block_size,
num_cpu_blocks=16,
num_gpu_blocks=16,
max_token_budget=max_num_batched_tokens,
max_num_seqs=max_seq_group,
max_model_len=100,
enable_prefix_caching=True,
enable_chunked_prefill=True,
)
seqA_tokens = list(range(8))
seqB_tokens = seqA_tokens
seqC_shared_prefix_len = 4
seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20))
seqA, seqA_group = create_dummy_prompt("0",
prompt_tokens=seqA_tokens,
block_size=block_size)
seqB, seqB_group = create_dummy_prompt("1",
prompt_tokens=seqB_tokens,
block_size=block_size)
# Chunked prefill seqA.
scheduler.add_seq_group(seqA_group)
metas, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
assert out.scheduled_seq_groups[0].seq_group == seqA_group
assert out.scheduled_seq_groups[0].token_chunk_size == 4
# seqB should not be scheduled with ongoing prefills.
scheduler.add_seq_group(seqB_group)
metas, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
assert out.scheduled_seq_groups[0].seq_group == seqA_group
assert out.scheduled_seq_groups[0].token_chunk_size == 4
# Both seqB and seqC can now be scheduled once seqA's prefill is over:
# seqA moves into the decoding phase.
append_new_token_seq(seqA, 999)
seqC, seqC_group = create_dummy_prompt("2",
prompt_tokens=seqC_tokens,
block_size=block_size)
scheduler.add_seq_group(seqC_group)
metas, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 3
metas = {meta.request_id: meta for meta in metas}
assert metas[seqA_group.request_id].token_chunk_size == 1 # Decode
assert (metas[seqB_group.request_id].token_chunk_size == 8
) # Fully cached prefill
assert metas[seqC_group.request_id].token_chunk_size == 6, (
"The shared prefix of seqC (4 tokens) should be scheduled as cached, "
"with the remaining uncached tokens fitting into the leftover token "
"budget (4 - 1 for seqA's decode - 1 for seqB), i.e. 2 tokens after "
"rounding down to the block size, thus 6 tokens in total.")

View File

@ -1,17 +1,20 @@
import time
from typing import List, Optional
from collections import defaultdict
from typing import Any, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple
from vllm import SamplingParams
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.inputs import EncoderDecoderInputs, token_inputs
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, Sequence, SequenceGroup
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupMetadata)
def create_dummy_prompt(
request_id: str,
prompt_length: int,
prompt_length: int = -1,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
best_of: int = 1,
@ -26,6 +29,7 @@ def create_dummy_prompt(
# Create dummy prompt sequence with tokens 0...block_size-1 # Create dummy prompt sequence with tokens 0...block_size-1
# and prompt "0 ... block_size". # and prompt "0 ... block_size".
prompt_tokens = list(range(prompt_length)) prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id), prompt = Sequence(int(request_id),
inputs=token_inputs(prompt_tokens, prompt=prompt_str), inputs=token_inputs(prompt_tokens, prompt=prompt_str),
@ -42,6 +46,15 @@ def create_dummy_prompt(
return prompt, seq_group return prompt, seq_group
def create_dummy_sequence(request_id: int, token_ids: List[int],
block_size: int) -> Sequence:
return Sequence(
seq_id=request_id,
inputs=token_inputs(token_ids),
block_size=block_size,
)
def create_dummy_prompt_encoder_decoder( def create_dummy_prompt_encoder_decoder(
request_id: str, request_id: str,
decoder_prompt_length: int, decoder_prompt_length: int,
@ -194,12 +207,40 @@ def append_new_token(out, token_id: int):
def schedule_and_update_computed_tokens(scheduler):
metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
for s in out.scheduled_seq_groups:
s.seq_group.update_num_computed_tokens(s.token_chunk_size)
return metas, out
def append_new_token_seq(seq: Sequence, token_id: int):
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
seq_group.update_num_computed_tokens(token_chunk_size)
for seq in seq_group.get_seqs():
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
class SchedulerProxy:
"""
A proxy that forwards calls to the wrapped scheduler and records each
call's arguments and results for later inspection in tests.
"""
def __init__(self, scheduler: Scheduler):
self.scheduler_ = scheduler
self.call_history: Dict[str, List[Any]] = defaultdict(list)
def __getattr__(self, name: str) -> Any:
def wrapper(*args, **kwargs):
result = getattr(self.scheduler_, name)(*args, **kwargs)
self.call_history[name].append((args, kwargs, result))
return result
return wrapper
def last_schedule_ret(
self, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Any]:
_, _, ret = self.call_history["schedule"][-1]
return ret

View File

@ -2,10 +2,15 @@
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import pytest
from tests.conftest import VllmRunner
from tests.core.utils import SchedulerProxy, create_dummy_prompt
from tests.kernels.utils import override_backend_env_variable
from vllm import SamplingParams, TokensPrompt
from vllm.core.scheduler import Scheduler
from vllm.engine.llm_engine import LLMEngine
from ..models.utils import check_outputs_equal
@ -27,6 +32,7 @@ UNSTABLE_PROMPT_SEQUENCE = [
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1]) @pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
@pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("block_size", [16])
def test_mixed_requests( def test_mixed_requests(
hf_runner, hf_runner,
@ -37,6 +43,7 @@ def test_mixed_requests(
dtype: str,
max_tokens: int,
cached_position: int,
enable_chunked_prefill: bool,
block_size: int,
monkeypatch,
) -> None:
@ -55,6 +62,7 @@ def test_mixed_requests(
model,
dtype=dtype,
enable_prefix_caching=True,
enable_chunked_prefill=enable_chunked_prefill,
block_size=block_size,
) as vllm_model:
# Run the first prompt so the cache is populated
@ -72,13 +80,13 @@ def test_mixed_requests(
block_size) * block_size
else:
expected_num_cached_tokens = 0
assert req_outputs[
i].num_cached_tokens == expected_num_cached_tokens
assert (
req_outputs[i].num_cached_tokens == expected_num_cached_tokens)
vllm_outputs = [
(output.prompt_token_ids + list(output.outputs[0].token_ids),
output.prompt + output.outputs[0].text) for output in req_outputs
]
vllm_outputs = [(
output.prompt_token_ids + list(output.outputs[0].token_ids),
output.prompt + output.outputs[0].text,
) for output in req_outputs]
check_outputs_equal(
outputs_0_lst=hf_outputs,
@ -105,3 +113,89 @@ def test_unstable_prompt_sequence(
for prompt in UNSTABLE_PROMPT_SEQUENCE:
vllm_model.generate(TokensPrompt(prompt_token_ids=prompt),
SamplingParams(max_tokens=1))
@pytest.mark.parametrize("model", MODELS)
def test_fully_cached_prefill_needs_uncached_token(model):
block_size = 16
max_num_batched_tokens = 16
num_output_tokens = 5
# Make a vllm engine
runner = VllmRunner(
model_name=model,
gpu_memory_utilization=0.7,
enable_chunked_prefill=True,
enforce_eager=True,
enable_prefix_caching=True,
block_size=block_size,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_batched_tokens,
)
engine: LLMEngine = runner.model.llm_engine
scheduler: Scheduler = SchedulerProxy(engine.scheduler[0]) # type: ignore
engine.scheduler[0] = scheduler
# SeqA
seqA_tokens = list(range(2 * block_size))
seqA, seq_groupA = create_dummy_prompt(
request_id="0",
prompt_tokens=seqA_tokens,
max_tokens=num_output_tokens,
block_size=block_size,
)
scheduler.add_seq_group(seq_groupA)
assert seqA.data.get_num_computed_tokens() == 0
# Prefill seqA
while not seqA.is_finished():
engine.step()
# seqB
seqB_tokens = [t + 1 for t in seqA_tokens] # shift by 1
seqB, seq_groupB = create_dummy_prompt(
request_id="1",
prompt_tokens=seqB_tokens,
max_tokens=num_output_tokens,
block_size=block_size,
)
# seqC is the same as seqA
seqC, seq_groupC = create_dummy_prompt(
request_id="2",
prompt_tokens=seqA_tokens,
max_tokens=num_output_tokens,
block_size=block_size,
)
scheduler.add_seq_group(seq_groupB)
scheduler.add_seq_group(seq_groupC)
# Even though seqC is fully cached, it should not be prefilled since we
# require at least 1 uncached token.
engine.step()
sched_metas, sched_out, _ = scheduler.last_schedule_ret()
assert len(sched_out.scheduled_seq_groups) == 1
assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
seq_groupB.request_id)
assert (sched_out.scheduled_seq_groups[0].token_chunk_size ==
max_num_batched_tokens)
# When seqB is finished, seqC could be prefilled.
while not seqB.is_finished():
engine.step()
sched_metas, sched_out, _ = scheduler.last_schedule_ret()
assert len(sched_out.scheduled_seq_groups) == 1
assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
seq_groupB.request_id)
engine.step()
sched_metas, sched_out, _ = scheduler.last_schedule_ret()
assert len(sched_out.scheduled_seq_groups) == 1
assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
seq_groupC.request_id)
assert sched_out.scheduled_seq_groups[0].token_chunk_size == len(
seqA_tokens)

View File

@ -306,14 +306,6 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
device = Device.GPU device = Device.GPU
return self._allocators[device].mark_blocks_as_computed(block_ids) return self._allocators[device].mark_blocks_as_computed(block_ids)
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].get_computed_block_ids(
prev_computed_block_ids, block_ids, skip_last_block_id)
def get_common_computed_block_ids( def get_common_computed_block_ids(
self, computed_seq_block_ids: List[List[int]]) -> List[int]: self, computed_seq_block_ids: List[List[int]]) -> List[int]:
# Prefix caching only supported on GPU. # Prefix caching only supported on GPU.
@ -342,6 +334,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
self._swap_mapping.clear() self._swap_mapping.clear()
return list(mapping.items()) return list(mapping.items())
def find_cached_blocks_prefix(
self,
block_hashes: List[int],
device: Device = Device.GPU,
) -> List[int]:
return self._allocators[device].find_cached_blocks_prefix(block_hashes)
class NullBlock(Block): class NullBlock(Block):
""" """

View File

@ -159,12 +159,6 @@ class BlockAllocator(ABC):
def mark_blocks_as_computed(self, block_ids: List[int]) -> None: def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass pass
@abstractmethod
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
pass
@abstractmethod @abstractmethod
def get_common_computed_block_ids( def get_common_computed_block_ids(
self, computed_seq_block_ids: List[List[int]]) -> List[int]: self, computed_seq_block_ids: List[List[int]]) -> List[int]:
@ -192,6 +186,13 @@ class BlockAllocator(ABC):
class NoFreeBlocksError(ValueError): class NoFreeBlocksError(ValueError):
pass pass
@abstractmethod
def find_cached_blocks_prefix(
self,
block_hashes: List[int],
) -> List[int]:
pass
class DeviceAwareBlockAllocator(ABC): class DeviceAwareBlockAllocator(ABC):
@ -207,9 +208,12 @@ class DeviceAwareBlockAllocator(ABC):
pass pass
@abstractmethod @abstractmethod
def allocate_immutable_blocks(self, prev_block: Optional[Block], def allocate_immutable_blocks(
self,
prev_block: Optional[Block],
block_token_ids: List[List[int]], block_token_ids: List[List[int]],
device: Device) -> List[Block]: device: Device,
) -> List[Block]:
pass pass
@abstractmethod @abstractmethod
@ -246,12 +250,6 @@ class DeviceAwareBlockAllocator(ABC):
def mark_blocks_as_computed(self, block_ids: List[int]) -> None: def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass pass
@abstractmethod
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
pass
@abstractmethod @abstractmethod
def get_common_computed_block_ids( def get_common_computed_block_ids(
self, computed_seq_block_ids: List[List[int]]) -> List[int]: self, computed_seq_block_ids: List[List[int]]) -> List[int]:
@ -284,3 +282,11 @@ class DeviceAwareBlockAllocator(ABC):
def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_prefix_cache_hit_rate(self, device: Device) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled.""" """Prefix cache hit rate. -1 means not supported or disabled."""
pass pass
@abstractmethod
def find_cached_blocks_prefix(
self,
block_hashes: List[int],
device: Device = Device.GPU,
) -> List[int]:
pass

View File

@ -262,13 +262,6 @@ class NaiveBlockAllocator(BlockAllocator):
""" """
pass pass
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
"""No prefix caching here => return empty list
"""
return []
def get_common_computed_block_ids( def get_common_computed_block_ids(
self, computed_seq_block_ids: List[List[int]]) -> List[int]: self, computed_seq_block_ids: List[List[int]]) -> List[int]:
"""Determine blocks that can be skipped in prefill. """Determine blocks that can be skipped in prefill.
@ -329,6 +322,10 @@ class NaiveBlockAllocator(BlockAllocator):
def get_prefix_cache_hit_rate(self) -> float: def get_prefix_cache_hit_rate(self) -> float:
return -1 return -1
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
# Not applicable for naive block allocator.
return []
class NaiveBlock(Block): class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix """An implementation of the Block class that does not support prefix

View File

@ -1,13 +1,18 @@
"""Token blocks.""" """Token blocks."""
import sys
from bisect import bisect_left
from os.path import commonprefix from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set,
Tuple)
from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
get_all_blocks_recursively) get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device,
DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import (BlockPool, NaiveBlock, from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
NaiveBlockAllocator) NaiveBlockAllocator)
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
from vllm.sequence import Sequence
PrefixHash = int PrefixHash = int
@ -534,26 +539,6 @@ class PrefixCachingBlockAllocator(BlockAllocator):
else: else:
return block_id in self.evictor return block_id in self.evictor
def get_computed_block_ids(self,
prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool = True) -> List[int]:
prev_prefix_size = len(prev_computed_block_ids)
cur_size = len(block_ids)
if skip_last_block_id:
cur_size -= 1
# Sanity checks
assert cur_size >= 0
assert prev_prefix_size <= cur_size
ret = prev_computed_block_ids
for i in range(prev_prefix_size, cur_size):
block_id = block_ids[i]
if self.block_is_computed(block_id):
ret.append(block_id)
return ret
def get_common_computed_block_ids( def get_common_computed_block_ids(
self, computed_seq_block_ids: List[List[int]]) -> List[int]: self, computed_seq_block_ids: List[List[int]]) -> List[int]:
"""Return the block ids that are common for a given sequence group. """Return the block ids that are common for a given sequence group.
@ -634,6 +619,47 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block.block_id = block_id # Assign block_id block.block_id = block_id # Assign block_id
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
"""
Given a list of block hashes, return the prefix of the block hashes that
are all cached.
Since a block's hash incorporates the hashes of all previous blocks, and
blocks are allocated and freed for a sequence as a whole, a cached block
implies that all of its predecessors are cached as well. This monotone
property lets us binary search for the end of the cached prefix.
Args:
block_hashes (List[int]): The list of block hashes.
Returns:
List[int]: The prefix of the `block_hashes` that are cached.
"""
def _block_is_cached(block_hash: PrefixHash) -> bool:
if block_hash not in self._cached_blocks:
return False
cached_block_id = self._cached_blocks[block_hash]
# We only consider the blocks that are marked as computed.
return self.block_is_computed(cached_block_id)
def _bisect_left(a, x, key: Callable[[PrefixHash], bool]) -> int:
# Python < 3.10 doesn't support the `key` argument to bisect_left.
if sys.version_info < (3, 10):
a = [key(e) for e in a]
return bisect_left(a, x)
else:
return bisect_left(a, x, key=key)
# Look for the first block that's not cached; everything before it is
# the cached prefix.
idx = _bisect_left(block_hashes,
True,
key=lambda x: not _block_is_cached(x))
return block_hashes[:idx]
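As an aside, the monotone property described in the docstring is what makes the bisect valid: the predicate "block is cached" is True for a prefix of the hash list and False afterwards. A small standalone sketch of the same idea (a hypothetical helper, not code from this PR):

from bisect import bisect_left

def cached_prefix(block_hashes, is_cached):
    # Map the hashes to a sorted [False, ..., False, True, ...] list by
    # negating the predicate, then bisect for the first non-cached block.
    flags = [not is_cached(h) for h in block_hashes]
    first_uncached = bisect_left(flags, True)
    return block_hashes[:first_uncached]

cache = {"h0", "h1", "h2"}
assert cached_prefix(["h0", "h1", "h2", "h3"],
                     cache.__contains__) == ["h0", "h1", "h2"]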
class PrefixCachingBlock(Block): class PrefixCachingBlock(Block):
"""A block implementation that supports prefix caching. """A block implementation that supports prefix caching.
@ -843,86 +869,126 @@ class PrefixCachingBlock(Block):
class ComputedBlocksTracker:
"""Handles caching of per-sequence computed block ids.
When a sequence appears for the first time, it traverses all of the
blocks and detects the prefix of blocks that is computed. On the
subsequent times, it only traverses the new blocks that were added
and updates the already recorded prefix of blocks with the newly
computed blocks.
To avoid redundant traversals, the algorithm also detects when there
is a "gap" in the computed prefix. For example, if we have blocks =
[1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then
we won't try to add more computed blocks to [1,2,3] in this sequence
iteration, and will add more computed blocks only after the sequence is
freed and reused again.
Note that currently, for a given sequence, we also skip the last
block id for caching purposes, to avoid caching of a full sequence
"""
"""
Tracks the computed blocks for each sequence.
Internally, it maintains a map from sequence id to the list of block hashes
for the sequence. We cache the hashes of the full blocks for each sequence,
and make sure the hash is calculated in the same way as the allocator.
When a sequence is being decoded, we also update the sequence's hashes
accordingly and incrementally.
With prefix caching enabled, we can then derive the number of cached tokens
for the sequence by looking up how many of its block hashes are cached in
the allocator.
"""
def __init__(self, allocator):
self._allocator = allocator
self._cached_computed_seq_blocks: Dict[int, Tuple[List[int],
bool]] = {}
def add_seq(self, seq_id: int) -> None:
"""Start tracking seq_id
"""
assert seq_id not in self._cached_computed_seq_blocks
self._cached_computed_seq_blocks[seq_id] = ([], False)
def __init__(
self,
allocator: DeviceAwareBlockAllocator,
block_size: int,
enable_caching: bool,
):
self._allocator = allocator
self._block_size = block_size
self._enable_caching = enable_caching
# A map from seq_id to the list of block hashes for the
# sequence. This is so that we don't have to recompute the block hashes
# for the sequence when we need to check if the sequence is cached.
# Note a block that's not full will not have its hash calculated and
# recorded.
self._seq_id_to_blocks_hashes: Dict[int, List[int]] = {}
# A map from seq_id to the number of tokens that are cached for the
# sequence.
# We need this so that a sequence in continuous prefill doesn't
# accidentally see its cached token count change. See comments in
# `get_num_cached_tokens` for more details.
self._seq_id_to_num_tokens_computed: Dict[int, int] = {}
def _update_seq_hashes(self, seq: Sequence) -> None:
"""Incrementally update the sequence's block hashes and record them."""
assert self._enable_caching
block_hashes_recorded = self._seq_id_to_blocks_hashes.get(
seq.seq_id, [])
cur_num_blocks_recorded = len(block_hashes_recorded)
token_ids = seq.get_token_ids()
assert len(token_ids) >= cur_num_blocks_recorded * self._block_size, (
f"The sequence has {len(token_ids)} tokens, but"
f" already recorded {cur_num_blocks_recorded} blocks. "
"This should not happen since we assume blocks are "
"only appended other than recomputation. When the sequence is "
"recomputed, we should have removed the info of the old blocks.")
# Update the computed block hashes for the sequence. Since only full
# blocks are considered as "computed", we take floor here.
num_computed_blocks = len(token_ids) // self._block_size
# We need to know the hash of the previous block to compute the hash of
# the current block so that blocks could be uniquely identified across
# sequences of prefixes.
prev_block_hash = (None if cur_num_blocks_recorded == 0 else
block_hashes_recorded[-1])
# Only update the computed block hashes for the new blocks
for i in range(cur_num_blocks_recorded, num_computed_blocks):
assert len(token_ids) >= (i + 1) * self._block_size
block_token_ids = token_ids[i * self._block_size:(i + 1) *
self._block_size]
# This has to be kept in sync with the allocator's hash
# calculation.
block_hash = PrefixCachingBlock.hash_block_tokens(
is_first_block=prev_block_hash is None,
prev_block_hash=prev_block_hash,
cur_block_token_ids=block_token_ids,
)
block_hashes_recorded.append(block_hash)
prev_block_hash = block_hash
self._seq_id_to_blocks_hashes[seq.seq_id] = block_hashes_recorded
def get_num_cached_tokens(self, seq: Sequence) -> int:
if not self._enable_caching:
return 0
# We always try to update the sequence hashes on the fly.
# This is to ensure that we don't miss any cached tokens for the
# sequence during decode.
# This routine should only update hash for any new blocks too.
self._update_seq_hashes(seq)
num_computed_tokens_prev = self._seq_id_to_num_tokens_computed.get(
seq.seq_id, None)
# TODO(rickyx): This hack could be removed once we mark blocks as
# computed correctly with chunked prefills.
if num_computed_tokens_prev is not None and seq.is_prefill():
# For a sequence that is still in prefill, we don't
# recompute the number of cached tokens.
# This also correctly handles chunked prefill, since currently
# we mark blocks as computed even if the sequence is only partially
# prefilled. So a continuously prefilled sequence should not
# see its cached token count change while running.
return num_computed_tokens_prev
block_hashes = self._seq_id_to_blocks_hashes[seq.seq_id]
# This is O(logN), where N is the number of blocks.
num_cached_blocks = len(
self._allocator.find_cached_blocks_prefix(block_hashes))
num_cached_tokens = num_cached_blocks * self._block_size
self._seq_id_to_num_tokens_computed[seq.seq_id] = num_cached_tokens
return num_cached_tokens
def remove_seq(self, seq_id: int) -> None:
"""Stop tracking seq_id
"""
assert seq_id in self._cached_computed_seq_blocks
del self._cached_computed_seq_blocks[seq_id]
def remove_seq(self, seq_id: int) -> None:
"""Stop tracking the sequence."""
if not self._enable_caching:
return
assert seq_id in self._seq_id_to_blocks_hashes
del self._seq_id_to_blocks_hashes[seq_id]
assert seq_id in self._seq_id_to_num_tokens_computed
del self._seq_id_to_num_tokens_computed[seq_id]
def get_cached_computed_blocks_and_update(
self, seq_id: int, block_ids: List[int]) -> List[int]:
""" Look at the class documentation for details
"""
# Ensure seq_id is already tracked
assert seq_id in self._cached_computed_seq_blocks
# Get cached data (may be empty on the first time)
prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[
seq_id]
if has_gap:
# When gap is detected, we do not add more computed blocks at this
# sequence iteration
return prev_computed_block_ids
# We do not consider the last block id for caching purposes.
num_cur_blocks = len(block_ids) - 1
assert num_cur_blocks >= 0
if len(prev_computed_block_ids) >= num_cur_blocks:
# Cache HIT
assert len(prev_computed_block_ids) == num_cur_blocks
return prev_computed_block_ids
# If here, then we may possibly add more computed blocks. As a result,
# traverse the additional blocks after prev_computed_block_ids to
# detect more computed blocks and add them.
# Incremental init for seq_id => Look only at the new blocks
computed_block_ids = self._allocator.get_computed_block_ids( # noqa: E501
prev_computed_block_ids,
block_ids,
skip_last_block_id=
True, # We skip last block id to avoid caching of full seq
)
# Detect if there is a "gap"
has_gap = len(computed_block_ids) < num_cur_blocks
# Record
self._cached_computed_seq_blocks[seq_id] = (computed_block_ids,
has_gap)
return computed_block_ids
class LastAccessBlocksTracker: class LastAccessBlocksTracker:

View File

@ -101,7 +101,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
self._computed_blocks_tracker = ComputedBlocksTracker(
self.block_allocator)
self._computed_blocks_tracker = ComputedBlocksTracker(
self.block_allocator, self.block_size, self.enable_caching)
self._last_access_blocks_tracker = LastAccessBlocksTracker(
self.block_allocator)
@ -170,7 +170,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self.block_tables[seq.seq_id] = block_table self.block_tables[seq.seq_id] = block_table
# Track seq # Track seq
self._computed_blocks_tracker.add_seq(seq.seq_id)
self._last_access_blocks_tracker.add_seq(seq.seq_id) self._last_access_blocks_tracker.add_seq(seq.seq_id)
# Assign the block table for each sequence. # Assign the block table for each sequence.
@ -178,7 +177,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self.block_tables[seq.seq_id] = block_table.fork() self.block_tables[seq.seq_id] = block_table.fork()
# Track seq # Track seq
self._computed_blocks_tracker.add_seq(seq.seq_id)
self._last_access_blocks_tracker.add_seq(seq.seq_id) self._last_access_blocks_tracker.add_seq(seq.seq_id)
# Allocate cross-attention block table for encoder sequence # Allocate cross-attention block table for encoder sequence
@ -314,11 +312,13 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
""" """
computed_seq_block_ids = [] computed_seq_block_ids = []
for seq in seqs: for seq in seqs:
computed_seq_block_ids.append( all_blocks = self.block_tables[seq.seq_id].physical_block_ids
self._computed_blocks_tracker. num_cached_tokens = (
get_cached_computed_blocks_and_update( self._computed_blocks_tracker.get_num_cached_tokens(seq))
seq.seq_id, assert num_cached_tokens % self.block_size == 0
self.block_tables[seq.seq_id].physical_block_ids)) num_cached_blocks = num_cached_tokens // self.block_size
computed_block_ids = all_blocks[:num_cached_blocks]
computed_seq_block_ids.append(computed_block_ids)
# NOTE(sang): This assumes seq_block_ids doesn't contain any None. # NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return self.block_allocator.get_common_computed_block_ids( return self.block_allocator.get_common_computed_block_ids(
@ -332,7 +332,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self.block_tables[child_seq.seq_id] = src_block_table.fork() self.block_tables[child_seq.seq_id] = src_block_table.fork()
# Track child seq # Track child seq
self._computed_blocks_tracker.add_seq(child_seq.seq_id)
self._last_access_blocks_tracker.add_seq(child_seq.seq_id) self._last_access_blocks_tracker.add_seq(child_seq.seq_id)
def can_swap_in(self, seq_group: SequenceGroup, def can_swap_in(self, seq_group: SequenceGroup,
@ -503,3 +502,9 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
return AllocStatus.OK return AllocStatus.OK
else: else:
return AllocStatus.LATER return AllocStatus.LATER
def get_num_cached_tokens(self, seq: Sequence) -> int:
"""Get the number of tokens in blocks that are already computed and
cached in the block manager for the sequence.
"""
return self._computed_blocks_tracker.get_num_cached_tokens(seq)

View File

@ -121,3 +121,7 @@ class BlockSpaceManager(ABC):
def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_prefix_cache_hit_rate(self, device: Device) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled.""" """Prefix cache hit rate. -1 means not supported or disabled."""
pass pass
@abstractmethod
def get_num_cached_tokens(self, seq: Sequence) -> int:
pass

View File

@ -89,3 +89,6 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_prefix_cache_hit_rate(self, device: Device) -> float:
return -1 return -1
def get_num_cached_tokens(self, seq: Sequence) -> int:
return 0

View File

@ -56,11 +56,16 @@ class SchedulingBudget:
max_num_seqs: int
_request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
_request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
# Number of cached tokens in the batch.
_num_cached_tokens: int = 0
# Number of actual non-cached tokens in the batch.
_num_batched_tokens: int = 0
_num_curr_seqs: int = 0
def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
assert num_new_tokens != 0
# We allow num_new_tokens to be 0 when the entire sequence has
# been cached.
assert num_new_tokens >= 0
assert num_new_seqs != 0
return (self.num_batched_tokens + num_new_tokens <= self.token_budget
and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
@ -68,12 +73,18 @@ class SchedulingBudget:
def remaining_token_budget(self):
return self.token_budget - self.num_batched_tokens
def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
def add_num_batched_tokens(self,
req_id: str,
num_batched_tokens: int,
num_cached_tokens: int = 0):
if req_id in self._request_ids_num_batched_tokens:
return
assert num_cached_tokens >= 0
assert num_batched_tokens >= 0
self._request_ids_num_batched_tokens.add(req_id)
self._num_batched_tokens += num_batched_tokens
self._num_cached_tokens += num_cached_tokens
def subtract_num_batched_tokens(self, req_id: str, def subtract_num_batched_tokens(self, req_id: str,
num_batched_tokens: int): num_batched_tokens: int):
@ -101,6 +112,10 @@ class SchedulingBudget:
def num_curr_seqs(self):
return self._num_curr_seqs
@property
def num_cached_tokens(self):
return self._num_cached_tokens
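A small sketch of how the split accounting above is meant to behave (illustrative numbers; it assumes SchedulingBudget is constructed with token_budget and max_num_seqs as elsewhere in the scheduler):

budget = SchedulingBudget(token_budget=12, max_num_seqs=4)
# A request with 8 new tokens, 4 of them cached: only the uncached half
# is charged against the token budget.
budget.add_num_batched_tokens("req-0", num_batched_tokens=4, num_cached_tokens=4)
assert budget.num_batched_tokens == 4
assert budget.num_cached_tokens == 4
assert budget.remaining_token_budget() == 8
# SchedulerOutputs.num_batched_tokens is later reported as the sum (4 + 4).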
@dataclass @dataclass
class ScheduledSequenceGroup: class ScheduledSequenceGroup:
@ -541,9 +556,19 @@ class Scheduler:
assert len(self._async_stopped) == 0
while running_queue:
seq_group = running_queue[0]
num_running_tokens = self._get_num_new_tokens(
seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
# We discard the cached tokens info here because we don't need it
# for a running sequence:
# 1. If a sequence is running with chunked prefill, the cached
# tokens info was already used for the first prefill.
# 2. If a sequence is running with non-chunked prefill, then
# it's a decoding sequence, and the cached tokens info is
# irrelevant.
num_uncached_new_tokens, _ = (
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.RUNNING, enable_chunking,
budget))
num_running_tokens = num_uncached_new_tokens
if num_running_tokens == 0: if num_running_tokens == 0:
# No budget => Stop # No budget => Stop
break break
@ -715,13 +740,15 @@ class Scheduler:
# The total number of sequences in the RUNNING state should not # The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences. # exceed the maximum number of sequences.
num_new_seqs = seq_group.get_max_num_running_seqs() num_new_seqs = seq_group.get_max_num_running_seqs()
num_new_tokens = self._get_num_new_tokens(seq_group,
SequenceStatus.SWAPPED,
enable_chunking, budget)
num_new_tokens_uncached, num_new_tokens_cached = (
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.SWAPPED, enable_chunking,
budget))
if (num_new_tokens == 0
or not budget.can_schedule(num_new_tokens=num_new_tokens,
num_new_seqs=num_new_seqs)):
if num_new_tokens_uncached == 0 or not budget.can_schedule(
num_new_tokens=num_new_tokens_uncached,
num_new_seqs=num_new_seqs,
):
break
if lora_int_id > 0 and curr_loras is not None:
is_prefill = seq_group.is_prefill()
if is_prefill:
prefill_seq_groups.append(
ScheduledSequenceGroup(seq_group,
token_chunk_size=num_new_tokens))
prefill_seq_groups.append(
ScheduledSequenceGroup(
seq_group,
token_chunk_size=num_new_tokens_uncached +
num_new_tokens_cached,
))
else:
decode_seq_groups.append(
ScheduledSequenceGroup(seq_group, token_chunk_size=1))
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
budget.add_num_batched_tokens(
seq_group.request_id,
num_batched_tokens=num_new_tokens_uncached,
num_cached_tokens=num_new_tokens_cached,
)
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
swapped_queue.extendleft(leftover_swapped) swapped_queue.extendleft(leftover_swapped)
@ -803,26 +837,30 @@ class Scheduler:
if waiting_queue:
seq_group = waiting_queue.popleft()
num_new_seqs = seq_group.get_max_num_running_seqs()
num_new_tokens = self._get_num_new_tokens(seq_group,
SequenceStatus.WAITING,
False, budget)
num_new_tokens_uncached, _ = (
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.WAITING, False, budget))
# Only preempt if priority inversion exists
while running_queue and self._get_priority(
running_queue[-1]) > self._get_priority(seq_group):
# Only preempt if waiting sequence cannot be allocated
can_allocate = self.block_manager.can_allocate(seq_group)
if (num_new_tokens and can_allocate == AllocStatus.OK
and budget.can_schedule(num_new_tokens=num_new_tokens,
num_new_seqs=num_new_seqs)):
if (num_new_tokens_uncached > 0
and can_allocate == AllocStatus.OK
and budget.can_schedule(
num_new_tokens=num_new_tokens_uncached,
num_new_seqs=num_new_seqs,
)):
break
# Adjust budget to remove the victim sequence group
vseq_group = running_queue.pop()
num_running_tokens = self._get_num_new_tokens(
vseq_group, SequenceStatus.RUNNING, False, budget)
num_running_tokens_uncached, _ = (
self._get_num_new_uncached_and_cached_tokens(
vseq_group, SequenceStatus.RUNNING, False, budget))
budget.subtract_num_batched_tokens(vseq_group.request_id,
num_running_tokens)
budget.subtract_num_batched_tokens(
vseq_group.request_id, num_running_tokens_uncached)
num_running_seqs = vseq_group.get_max_num_running_seqs()
budget.subtract_num_seqs(vseq_group.request_id,
num_running_seqs)
@ -882,9 +920,12 @@ class Scheduler:
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
num_new_tokens = self._get_num_new_tokens(seq_group,
SequenceStatus.WAITING,
enable_chunking, budget)
num_new_tokens_uncached, num_new_tokens_cached = (
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.WAITING, enable_chunking,
budget))
num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached
if not enable_chunking:
num_prompt_tokens = waiting_seqs[0].get_len()
assert num_new_tokens == num_prompt_tokens
@ -935,10 +976,18 @@ class Scheduler:
waiting_queue.popleft()
continue
if (budget.num_batched_tokens >=
self.scheduler_config.max_num_batched_tokens):
# We've reached the budget limit - since there might be
# continuous prefills in the running queue, we should break
# to avoid scheduling any new prefills.
break
num_new_seqs = seq_group.get_max_num_running_seqs()
if (num_new_tokens == 0
or not budget.can_schedule(num_new_tokens=num_new_tokens,
num_new_seqs=num_new_seqs)):
if num_new_tokens_uncached == 0 or not budget.can_schedule(
num_new_tokens=num_new_tokens_uncached,
num_new_seqs=num_new_seqs,
):
break
# Can schedule this request. # Can schedule this request.
@ -967,7 +1016,11 @@ class Scheduler:
seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group,
token_chunk_size=num_new_tokens))
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
budget.add_num_batched_tokens(
seq_group.request_id,
num_batched_tokens=num_new_tokens_uncached,
num_cached_tokens=num_new_tokens_cached,
)
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
# Queue requests that couldn't be scheduled. # Queue requests that couldn't be scheduled.
@ -1075,7 +1128,8 @@ class Scheduler:
return SchedulerOutputs(
scheduled_seq_groups=scheduled_seq_groups,
num_prefill_groups=num_prefill_groups,
num_batched_tokens=budget.num_batched_tokens,
num_batched_tokens=budget.num_batched_tokens +
budget.num_cached_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
@ -1119,7 +1173,6 @@ class Scheduler:
running_scheduled.swapped_out) == 0: running_scheduled.swapped_out) == 0:
swapped_in = self._schedule_swapped(budget, curr_loras) swapped_in = self._schedule_swapped(budget, curr_loras)
# Schedule new prefills.
prefills = self._schedule_prefills(budget, prefills = self._schedule_prefills(budget,
curr_loras, curr_loras,
enable_chunking=True) enable_chunking=True)
@ -1157,7 +1210,8 @@ class Scheduler:
num_prefill_groups=(len(prefills.seq_groups) +
len(swapped_in.prefill_seq_groups) +
len(running_scheduled.prefill_seq_groups)),
num_batched_tokens=budget.num_batched_tokens,
num_batched_tokens=budget.num_batched_tokens +
budget.num_cached_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=running_scheduled.blocks_to_copy +
@ -1584,30 +1638,140 @@ class Scheduler:
return self.scheduler_config.num_lookahead_slots
def _get_num_new_tokens(self, seq_group: SequenceGroup,
status: SequenceStatus, enable_chunking: bool,
budget: SchedulingBudget) -> int:
"""Get the next new tokens to compute for a given sequence group
that's in a given `status`.
The API could chunk the number of tokens to compute based on `budget`
if `enable_chunking` is True. If a sequence group has multiple
sequences (e.g., running beam search), it means it is in decoding
phase, so chunking doesn't happen.
Returns 0 if the new token cannot be computed due to token budget.
"""
def _get_num_new_uncached_and_cached_tokens(
self,
seq_group: SequenceGroup,
status: SequenceStatus,
enable_chunking: bool,
budget: SchedulingBudget,
) -> Tuple[int, int]:
"""
Returns the number of new uncached and cached tokens to schedule for a
given sequence group that's in a given `status`.
The API could chunk the number of tokens to compute based on `budget`
if `enable_chunking` is True. If a sequence group has multiple
sequences (e.g., running beam search), it means it is in decoding
phase, so chunking doesn't happen.
Returns (0, 0) if no new tokens can be scheduled due to the token budget.
The cached tokens' blocks are already computed, and the attention
backend will reuse the cached blocks rather than recomputing them. So
the scheduler could schedule these cached tokens "for free".
Args:
seq_group: The sequence group to get the number of new tokens to
schedule.
status: The status of the sequences to get the number of new tokens
to schedule.
enable_chunking: Whether to chunk the number of tokens to compute.
budget: The budget to chunk the number of tokens to compute.
Returns:
A tuple of two ints. The first int is the number of new uncached
tokens to schedule. The second int is the number of cached tokens.
If no more new tokens can be scheduled, returns (0, 0).
"""
num_new_tokens = 0
num_cached_new_tokens = 0
num_uncached_new_tokens = 0
seqs = seq_group.get_seqs(status=status)
# Compute the number of new uncached and cached tokens for
# each sequence.
for seq in seqs:
num_new_tokens += seq.get_num_new_tokens()
assert num_new_tokens > 0
if not seq.is_prefill():
# Decode sequences should always just have 1 uncached token
# TODO(rickyx): Actually is this still correct for multi-step?
num_uncached_new_tokens += 1
continue
num_computed_tokens_seq = seq.get_num_computed_tokens()
all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq
if not self.cache_config.enable_prefix_caching:
# If prefix caching is not enabled, all new tokens are uncached.
num_uncached_new_tokens += all_num_new_tokens_seq
continue
# NOTE: the cache token might be currently in a block that's in an
# evictor meaning that it's not yet allocated. However, we don't
# exclude such tokens in the cache count because it will be
# guaranteed to be allocated later if the sequence can be allocated.
num_cached_tokens_seq = self.block_manager.get_num_cached_tokens(
seq)
# Sanity check.
if num_cached_tokens_seq < num_computed_tokens_seq:
# This should only happen with chunked prefill, and
# the seq is still in prefill. The `num_cached_tokens_seq`
# is the value we calculated on scheduling the first prefill.
# For subsequent continuous prefill steps, we cached the
# number of cache tokens for the sequence so the cached token
# count could be less than the number of computed tokens.
# See comments on `ComputedBlocksTracker` for more details.
assert (
seq.is_prefill() and seq.status == SequenceStatus.RUNNING
and self.scheduler_config.chunked_prefill_enabled
), ("Number of cached tokens should not be less than the "
"number of computed tokens for a sequence that's still "
f"in prefill. But there are {num_cached_tokens_seq} cached "
f"tokens and {num_computed_tokens_seq} computed tokens "
f"for sequence {seq.seq_id}.")
num_cached_new_tokens_seq = max(
0, num_cached_tokens_seq - num_computed_tokens_seq)
num_uncached_new_tokens_seq = (all_num_new_tokens_seq -
num_cached_new_tokens_seq)
num_uncached_new_tokens += num_uncached_new_tokens_seq
num_cached_new_tokens += num_cached_new_tokens_seq
if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0:
# For a fully cached hit sequence, we actually need to recompute the
# last token. So we need at least 1 uncached token to schedule.
# See ModelRunner._compute_for_prefix_cache_hit for more details.
num_uncached_new_tokens = 1
num_cached_new_tokens -= 1
if enable_chunking and len(seqs) == 1:
# Chunk if a running request cannot fit in the given budget.
# If number of seq > 1, it means it is doing beam search
# in a decode phase. Do not chunk.
if enable_chunking and len(seqs) == 1:
num_uncached_new_tokens = self._chunk_new_tokens_to_schedule(
self.scheduler_config,
self.cache_config,
budget,
self._get_prompt_limit(seq_group),
num_uncached_new_tokens,
)
return num_uncached_new_tokens, num_cached_new_tokens
@staticmethod
def _chunk_new_tokens_to_schedule(
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
budget: SchedulingBudget,
prompt_limit: int,
num_new_tokens: int,
) -> int:
"""
Chunks the number of new tokens to schedule based on the budget when
chunked prefill is enabled.
Args:
scheduler_config: The scheduler config.
cache_config: The cache config.
budget: The budget to chunk the number of tokens to compute.
prompt_limit: The maximum number of tokens allowed in a prompt.
num_new_tokens: The number of new tokens to schedule.
Returns:
The number of new tokens to schedule after chunking.
"""
remaining_token_budget = budget.remaining_token_budget()
if self.scheduler_config.is_multi_step:
if scheduler_config.is_multi_step:
# The current multi-step + chunked prefill capability does
# not actually support chunking prompts.
#
@ -1617,20 +1781,23 @@ class Scheduler:
# #
# Prompts with more tokens than the current remaining budget # Prompts with more tokens than the current remaining budget
# are postponed to future scheduler steps # are postponed to future scheduler steps
if num_new_tokens > self._get_prompt_limit(seq_group):
if num_new_tokens > prompt_limit:
# If the seq_group is in prompt-stage, pass the
# num_new_tokens as-is so the caller can ignore
# the sequence.
pass
return num_new_tokens
else:
num_new_tokens = 0 \
if num_new_tokens > remaining_token_budget \
else num_new_tokens
return (0 if num_new_tokens > remaining_token_budget else
num_new_tokens)
elif self.cache_config.enable_prefix_caching:
if cache_config.enable_prefix_caching:
# Adjust the remaining token budget to be divisible by the block
# size when prefix caching is enabled.
# When prefix caching is enabled, we always allocate # When prefix caching is enabled, we always allocate
# the number of new tokens that is dividable by the block # the number of new tokens that is dividable by the block
# size to avoid partial block matching. # size to avoid partial block matching.
block_size = self.cache_config.block_size block_size = cache_config.block_size
remainder = budget.token_budget % block_size remainder = budget.token_budget % block_size
if remainder != 0: if remainder != 0:
raise ValueError("When enabling chunked prefill and " raise ValueError("When enabling chunked prefill and "
@ -1639,9 +1806,10 @@ class Scheduler:
"block size, but got chunk_size " "block size, but got chunk_size "
f"({budget.token_budget}) % block_size " f"({budget.token_budget}) % block_size "
f"({block_size}) = {remainder}") f"({block_size}) = {remainder}")
if remaining_token_budget < num_new_tokens:
num_new_tokens = (remaining_token_budget //
block_size) * block_size
# Round down to block size.
remaining_token_budget = (remaining_token_budget // block_size *
block_size)
else:
num_new_tokens = min(num_new_tokens, remaining_token_budget)
return num_new_tokens
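A worked example of the rounding above, assuming prefix caching is enabled with block_size=16 and 30 tokens of budget remaining (illustrative numbers only, not taken from the PR):

block_size = 16
remaining_token_budget = 30
num_new_tokens = 50
# Round the remaining budget down to a block boundary to avoid partial
# block matching: 30 -> 16.
remaining_token_budget = remaining_token_budget // block_size * block_size
# The prefill chunk is then capped by the aligned budget.
num_new_tokens = min(num_new_tokens, remaining_token_budget)
assert num_new_tokens == 16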

View File

@ -579,6 +579,9 @@ class Sequence:
return 1
return self.data.get_num_uncomputed_tokens()
def get_num_computed_tokens(self) -> int:
return self.data.get_num_computed_tokens()
def is_prefill(self) -> bool:
return self.data.stage == SequenceStage.PREFILL