Prefix Cache Aware Scheduling [1/n] (#10128)
Signed-off-by: rickyx <rickyx@anyscale.com>
parent 7c25fe45a6
commit 4634a89d18
@ -5,9 +5,14 @@ from unittest.mock import MagicMock

import pytest

from tests.core.utils import create_dummy_sequence
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.prefix_caching_block import (PrefixCachingBlock,
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
                                                  PrefixCachingBlock,
                                                  PrefixCachingBlockAllocator)
from vllm.sequence import Logprob
from vllm.utils import Device


class TestPrefixCachingBlock:
@ -726,18 +731,71 @@ class TestPrefixCachingBlockAllocator:
            token_ids=common_token_ids,
            allocator=allocator,
        )
        block_ids = [block.block_id for block in blocks]
        block_hashes = [block.content_hash for block in blocks]
        # The allocated blocks should be marked as touched
        # but not computed.
        computed_block_ids = allocator.get_computed_block_ids(
            [], block_ids, skip_last_block_id=False)
        computed_block_ids = allocator.find_cached_blocks_prefix(
            block_hashes)
        assert len(computed_block_ids) == 0

        allocator.mark_blocks_as_computed([])
        computed_block_ids = allocator.get_computed_block_ids(
            [], block_ids, skip_last_block_id=False)
        computed_block_ids = allocator.find_cached_blocks_prefix(
            block_hashes=block_hashes)
        assert len(computed_block_ids) == common_blocks
    @staticmethod
    def test_find_cached_blocks_prefix():
        """
        This test verifies the behavior of find_cached_blocks_prefix.
        """
        block_size = 4
        num_blocks = 8
        total_test_blocks = 12
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)

        token_ids = list(range(total_test_blocks * block_size))
        block_tokens_seq1 = token_ids[:num_blocks * block_size]
        blocks_seq1 = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=block_tokens_seq1,
            allocator=allocator,
        )
        block_hashes_seq1 = [block.content_hash for block in blocks_seq1]
        allocator.mark_blocks_as_computed([])

        # All blocks should be cached.
        cached_blocks_seq1 = allocator.find_cached_blocks_prefix(
            block_hashes=block_hashes_seq1)
        assert len(cached_blocks_seq1) == num_blocks

        # Free the first sequence.
        for block in blocks_seq1:
            allocator.free(block)

        # All blocks should still be cached if they are not required to be
        # allocated.
        cached_blocks = allocator.find_cached_blocks_prefix(
            block_hashes=block_hashes_seq1)
        assert len(cached_blocks) == num_blocks

        block_tokens_seq2 = token_ids[num_blocks * block_size:]
        blocks_seq2 = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=block_tokens_seq2,
            allocator=allocator,
        )
        block_hashes_seq2 = [block.content_hash for block in blocks_seq2]
        allocator.mark_blocks_as_computed([])
        cached_blocks = allocator.find_cached_blocks_prefix(
            block_hashes=block_hashes_seq2)
        assert len(cached_blocks) == len(blocks_seq2)

        # Half of the blocks from seq1 should still be cached.
        num_evicted_blocks = len(blocks_seq2)
        cached_blocks = allocator.find_cached_blocks_prefix(
            block_hashes=block_hashes_seq1)
        assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks
    @staticmethod
    def create_immutable_chain(
        block_size: int,
@ -762,3 +820,114 @@ class TestPrefixCachingBlockAllocator:
            blocks.append(prev_block)

        return blocks

class TestComputedBlocksTracker:

    @staticmethod
    def _get_mock_allocator():
        return MagicMock(spec=PrefixCachingBlockAllocator)

    @staticmethod
    def test_get_num_cached_tokens():
        """
        Test that it correctly computes the number of cached tokens for a
        given sequence:

        - The cached token count is derived from the number of cached blocks.
        - The cached token count is updated when the allocator is updated.
        - When a sequence is removed, the cached token count should be updated
          accordingly.

        # TODO(rickyx): This behaviour for a prefill sequence is a hack until
        we fix the computed blocks tracking.
        - The cached token count for a prefill sequence doesn't change while
          the sequence is in continuous prefill (chunked prefill).
        """
        block_size = 4
        mock_allocator = TestComputedBlocksTracker._get_mock_allocator()
        tracker = ComputedBlocksTracker(
            allocator=mock_allocator,
            block_size=block_size,
            enable_caching=True,
        )

        # Not yet allocated.
        tokens = [0, 1, 2, 3, 4, 5]
        seq1 = create_dummy_sequence(request_id=0,
                                     token_ids=tokens,
                                     block_size=block_size)
        mock_allocator.find_cached_blocks_prefix.return_value = []
        assert tracker.get_num_cached_tokens(seq1) == 0

        mock_allocator.find_cached_blocks_prefix.return_value = [
            None
        ]  # 1 block cached.
        # The result is cached for a prefill sequence.
        assert tracker.get_num_cached_tokens(seq1) == 0

        # Mark the sequence as non-prefill.
        seq1.data.update_num_computed_tokens(len(tokens))  # 6 tokens computed.
        assert not seq1.is_prefill()

        # Recomputed for a decoding sequence.
        assert tracker.get_num_cached_tokens(seq1) == 4

        # Append new tokens to the sequence.
        num_new_tokens = 3
        for i in range(num_new_tokens):
            seq1.append_token_id(i, {i: Logprob(logprob=0.0)})

        assert tracker.get_num_cached_tokens(seq1) == 4

        # Update the allocator.
        mock_allocator.find_cached_blocks_prefix.return_value = [
            None
        ] * 2  # 2 blocks cached.
        assert tracker.get_num_cached_tokens(seq1) == 8

        # Remove the sequence.
        tracker.remove_seq(seq1.seq_id)

        # Re-create the sequence with the same request id to simulate
        # recomputation.
        seq1 = create_dummy_sequence(request_id=0,
                                     token_ids=tokens,
                                     block_size=block_size)
        mock_allocator.find_cached_blocks_prefix.return_value = [
        ]  # No cached block.
        assert tracker.get_num_cached_tokens(seq1) == 0
    @staticmethod
    def test_correct_block_hash():
        """
        Test that the block hash is correctly computed for a sequence (it
        should match the underlying block allocator's block hash), so that
        the number of cached tokens is correctly retrieved.
        """
        block_size = 4
        allocator = CpuGpuBlockAllocator.create(
            allocator_type="prefix_caching",
            num_gpu_blocks=16,
            num_cpu_blocks=16,
            block_size=block_size,
        )
        gpu_allocator = allocator._allocators[Device.GPU]

        tracker = ComputedBlocksTracker(
            allocator=allocator,
            block_size=block_size,
            enable_caching=True,
        )

        tokens = list(range(block_size * 4))  # 4 blocks.
        seq = create_dummy_sequence(request_id=0,
                                    token_ids=tokens,
                                    block_size=block_size)
        _ = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=tokens,
            allocator=gpu_allocator,
        )
        allocator.mark_blocks_as_computed([])

        assert tracker.get_num_cached_tokens(seq) == len(tokens)
@ -12,9 +12,9 @@ from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroup

from .utils import (append_new_token, append_new_token_seq_group,
                    create_dummy_prompt, get_sequence_groups,
                    schedule_and_update_computed_tokens)
from .utils import (append_new_token, append_new_token_seq,
                    append_new_token_seq_group, create_dummy_prompt,
                    get_sequence_groups, schedule_and_update_computed_tokens)


def test_scheduler_add_seq_group():
@ -305,6 +305,8 @@ def initialize_scheduler(
    block_size=4,
    num_cpu_blocks=8,
    num_gpu_blocks=8,
    enable_prefix_caching=False,
    enable_chunked_prefill=False,
):
    block_size = block_size
    scheduler_config = SchedulerConfig(
@ -312,8 +314,15 @@ def initialize_scheduler(
        max_num_batched_tokens=max_token_budget,
        max_num_seqs=max_num_seqs,
        max_model_len=max_model_len,
        enable_chunked_prefill=enable_chunked_prefill,
    )
    cache_config = CacheConfig(
        block_size,
        1.0,
        1,
        "auto",
        enable_prefix_caching=enable_prefix_caching,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = num_cpu_blocks
    cache_config.num_gpu_blocks = num_gpu_blocks
    scheduler = Scheduler(scheduler_config, cache_config, lora_config)
@ -800,3 +809,165 @@ def test_scheduling_budget():
    assert budget.num_curr_seqs == 0
    budget.subtract_num_seqs(seq_group.request_id, 2)
    assert budget.num_curr_seqs == 0

@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_prefix_caching_aware_prefills(enable_prefix_caching):
    """
    Test the scenario below:

    Three sequences, seqA, seqB, and seqC, share the first block as a prefix.

    The test verifies that:
    1. SeqA is scheduled first.
    2. SeqB and SeqC can be prefilled together in a single schedule round,
    even though there is not enough token budget to prefill both without
    taking prefix caching into account.
    """

    block_size = 4
    max_num_batched_tokens = 12
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=max_num_batched_tokens,
        max_num_seqs=max_seq_group,
        max_model_len=max_num_batched_tokens,
        enable_prefix_caching=enable_prefix_caching,
    )

    seqA_tokens = list(range(8))
    num_shared_tokens = 4
    seqB_tokens = seqA_tokens[:num_shared_tokens] + list(range(
        12, 16))  # Shared prefix first 4.
    seqC_tokens = seqA_tokens[:num_shared_tokens] + list(range(
        16, 20))  # Shared prefix first 4.

    seqA, seqA_group = create_dummy_prompt("0",
                                           prompt_tokens=seqA_tokens,
                                           block_size=block_size)
    seqB, seqB_group = create_dummy_prompt("1",
                                           prompt_tokens=seqB_tokens,
                                           block_size=block_size)
    seqC, seqC_group = create_dummy_prompt("2",
                                           prompt_tokens=seqC_tokens,
                                           block_size=block_size)

    # Schedule seqA prefill.
    scheduler.add_seq_group(seqA_group)
    metas, out, _ = scheduler.schedule()
    assert (len(out.scheduled_seq_groups) == 1
            and out.scheduled_seq_groups[0].seq_group == seqA_group)
    assert out.scheduled_seq_groups[0].token_chunk_size == len(seqA_tokens)

    # Schedule seqA decode.
    append_new_token_seq_group(len(seqA_tokens), seqA_group, 999)
    metas, out, _ = scheduler.schedule()

    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 1

    # Scheduling seqB and seqC prefills should work with prefix caching.
    scheduler.add_seq_group(seqB_group)
    scheduler.add_seq_group(seqC_group)
    metas, out, _ = scheduler.schedule()

    if enable_prefix_caching:
        assert len(out.scheduled_seq_groups) == 2
        assert set([
            out.scheduled_seq_groups[0].seq_group,
            out.scheduled_seq_groups[1].seq_group,
        ]) == set([seqB_group, seqC_group])
        assert len(metas) == 2
        for meta in metas:
            assert meta.token_chunk_size == 8
            assert (len(meta.computed_block_nums) == num_shared_tokens //
                    block_size)  # 1 block for the 4 shared tokens.
    else:
        assert len(out.scheduled_seq_groups) == 1
        assert len(metas) == 1
        assert metas[0].token_chunk_size == 8
        assert len(metas[0].computed_block_nums) == 0  # No blocks computed.

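The budget arithmetic behind the prefix-caching branch above, as a rough sketch. This is purely illustrative and not part of the diff; the variable names are made up for the example.

```python
# Illustrative only: why seqB and seqC fit in one round when prefix caching
# is on. Each prompt is 8 tokens, and the first block (4 tokens) is already
# computed by seqA, so only the uncached tail counts against the budget.
block_size = 4
token_budget = 12
prompt_len = 8
cached = 4                                       # shared, already-computed prefix
uncached_per_seq = prompt_len - cached           # 4
assert 2 * uncached_per_seq <= token_budget      # 8 <= 12 -> both prefills fit
assert 2 * prompt_len > token_budget             # 16 > 12 -> would not fit otherwise
```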
def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
):
    """
    This test verifies that we don't schedule new prefills if there's already
    a continuous prefill in progress, even though the new prefills with a
    shared prefix could fit in the token budget:

    - SeqA is being chunk-prefilled.
    - SeqB with the same prompt shouldn't be scheduled for prefill even though
    there's enough token budget to prefill the cached tokens.
    - Neither should seqC be scheduled.

    - When seqA is in the decoding phase, seqB and seqC can be scheduled.
        - The entire seqB should be prefilled since it's a full prefix cache
        hit.
        - SeqC would be partially prefilled with the shared prefix, and the
        remaining unique tokens would be prefilled (rounded down to be
        block-size aligned).
    """

    block_size = 2
    max_num_batched_tokens = 4
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=max_num_batched_tokens,
        max_num_seqs=max_seq_group,
        max_model_len=100,
        enable_prefix_caching=True,
        enable_chunked_prefill=True,
    )

    seqA_tokens = list(range(8))
    seqB_tokens = seqA_tokens
    seqC_shared_prefix_len = 4
    seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20))

    seqA, seqA_group = create_dummy_prompt("0",
                                           prompt_tokens=seqA_tokens,
                                           block_size=block_size)
    seqB, seqB_group = create_dummy_prompt("1",
                                           prompt_tokens=seqB_tokens,
                                           block_size=block_size)

    # Chunked prefill of seqA.
    scheduler.add_seq_group(seqA_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # seqB should not be scheduled while there is an ongoing prefill.
    scheduler.add_seq_group(seqB_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # Both seqB and seqC can now be scheduled once seqA's prefill is over.
    # seqA is in the decoding phase.
    append_new_token_seq(seqA, 999)
    seqC, seqC_group = create_dummy_prompt("2",
                                           prompt_tokens=seqC_tokens,
                                           block_size=block_size)
    scheduler.add_seq_group(seqC_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 3

    metas = {meta.request_id: meta for meta in metas}
    assert metas[seqA_group.request_id].token_chunk_size == 1  # Decode
    assert (metas[seqB_group.request_id].token_chunk_size == 8
            )  # Fully cached prefill
    assert (
        metas[seqC_group.request_id].token_chunk_size == 6
    ), ("A partial prefix of C (4 tokens) should be prefilled, with the "
        "remaining tokens fitting into the 3-token budget (4 minus 1 used by "
        "seqA's decode). That is then rounded down to 2 tokens to stay "
        "block-aligned, thus 6 tokens in total.")
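A rough sketch of the chunk-size arithmetic asserted above. Illustrative only; it mirrors the reasoning in the assert message rather than the scheduler internals.

```python
# Illustrative only: seqC's expected chunk with block_size=2 and a 4-token
# budget, of which 1 token is consumed by seqA's decode.
block_size = 2
budget_left = 4 - 1                                       # 3 tokens left for prefills
cached_prefix = 4                                         # seqC's already-computed prefix
uncached_fit = (budget_left // block_size) * block_size   # rounded down to 2
assert cached_prefix + uncached_fit == 6
```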
@ -1,17 +1,20 @@
import time
from typing import List, Optional
from collections import defaultdict
from typing import Any, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple

from vllm import SamplingParams
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.inputs import EncoderDecoderInputs, token_inputs
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, Sequence, SequenceGroup
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
                           SequenceGroupMetadata)


def create_dummy_prompt(
    request_id: str,
    prompt_length: int,
    prompt_length: int = -1,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    best_of: int = 1,
@ -26,6 +29,7 @@ def create_dummy_prompt(
    # Create dummy prompt sequence with tokens 0...block_size-1
    # and prompt "0 ... block_size".
    prompt_tokens = list(range(prompt_length))

    prompt_str = " ".join([str(t) for t in prompt_tokens])
    prompt = Sequence(int(request_id),
                      inputs=token_inputs(prompt_tokens, prompt=prompt_str),
@ -42,6 +46,15 @@ def create_dummy_prompt(
    return prompt, seq_group


def create_dummy_sequence(request_id: int, token_ids: List[int],
                          block_size: int) -> Sequence:
    return Sequence(
        seq_id=request_id,
        inputs=token_inputs(token_ids),
        block_size=block_size,
    )


def create_dummy_prompt_encoder_decoder(
    request_id: str,
    decoder_prompt_length: int,
@ -194,12 +207,40 @@ def append_new_token(out, token_id: int):
def schedule_and_update_computed_tokens(scheduler):
    metas, out, _ = scheduler.schedule()
    for s, meta in zip(out.scheduled_seq_groups, metas):
        s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
    for s in out.scheduled_seq_groups:
        s.seq_group.update_num_computed_tokens(s.token_chunk_size)
    return metas, out


def append_new_token_seq(seq: Sequence, token_id: int):
    seq.append_token_id(token_id, {token_id: Logprob(token_id)})


def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
    seq_group.update_num_computed_tokens(token_chunk_size)
    for seq in seq_group.get_seqs():
        seq.append_token_id(token_id, {token_id: Logprob(token_id)})


class SchedulerProxy:
    """
    A proxy class that forwards calls to the scheduler while recording them.
    """

    def __init__(self, scheduler: Scheduler):
        self.scheduler_ = scheduler
        self.call_history: Dict[str, List[Any]] = defaultdict(list)

    def __getattr__(self, name: str) -> Any:

        def wrapper(*args, **kwargs):
            result = getattr(self.scheduler_, name)(*args, **kwargs)
            self.call_history[name].append((args, kwargs, result))
            return result

        return wrapper

    def last_schedule_ret(
        self, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Any]:
        _, _, ret = self.call_history["schedule"][-1]
        return ret
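A minimal usage sketch of the proxy above. The `_DummyScheduler` stand-in is invented purely so the example runs without an engine; in the tests below, the real target is `engine.scheduler[0]`.

```python
class _DummyScheduler:
    """Stand-in object so SchedulerProxy can be demonstrated without an engine."""

    def schedule(self):
        return ["meta"], "outputs", None


# SchedulerProxy (defined above) forwards any method call and records
# (args, kwargs, result) under the method name.
proxy = SchedulerProxy(_DummyScheduler())  # type: ignore[arg-type]
metas, out, _ = proxy.schedule()

# The call history keeps the result of the last schedule() call...
assert proxy.call_history["schedule"][-1][2] == (["meta"], "outputs", None)
# ...and last_schedule_ret() returns it directly.
metas2, out2, _ = proxy.last_schedule_ret()
assert out2 == "outputs"
```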
@ -2,10 +2,15 @@

Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""

import pytest

from tests.conftest import VllmRunner
from tests.core.utils import SchedulerProxy, create_dummy_prompt
from tests.kernels.utils import override_backend_env_variable
from vllm import SamplingParams, TokensPrompt
from vllm.core.scheduler import Scheduler
from vllm.engine.llm_engine import LLMEngine

from ..models.utils import check_outputs_equal

@ -27,6 +32,7 @@ UNSTABLE_PROMPT_SEQUENCE = [
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
@pytest.mark.parametrize("block_size", [16])
def test_mixed_requests(
    hf_runner,
@ -37,6 +43,7 @@ def test_mixed_requests(
    dtype: str,
    max_tokens: int,
    cached_position: int,
    enable_chunked_prefill: bool,
    block_size: int,
    monkeypatch,
) -> None:
@ -55,6 +62,7 @@ def test_mixed_requests(
        model,
        dtype=dtype,
        enable_prefix_caching=True,
        enable_chunked_prefill=enable_chunked_prefill,
        block_size=block_size,
    ) as vllm_model:
        # Run the first prompt so the cache is populated
@ -72,13 +80,13 @@ def test_mixed_requests(
                block_size) * block_size
        else:
            expected_num_cached_tokens = 0
        assert req_outputs[
            i].num_cached_tokens == expected_num_cached_tokens
        assert (
            req_outputs[i].num_cached_tokens == expected_num_cached_tokens)

    vllm_outputs = [
        (output.prompt_token_ids + list(output.outputs[0].token_ids),
         output.prompt + output.outputs[0].text) for output in req_outputs
    ]
    vllm_outputs = [(
        output.prompt_token_ids + list(output.outputs[0].token_ids),
        output.prompt + output.outputs[0].text,
    ) for output in req_outputs]

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
@ -105,3 +113,89 @@ def test_unstable_prompt_sequence(
    for prompt in UNSTABLE_PROMPT_SEQUENCE:
        vllm_model.generate(TokensPrompt(prompt_token_ids=prompt),
                            SamplingParams(max_tokens=1))

@pytest.mark.parametrize("model", MODELS)
def test_fully_cached_prefill_needs_uncached_token(model):
    block_size = 16
    max_num_batched_tokens = 16
    num_output_tokens = 5
    # Make a vllm engine.
    runner = VllmRunner(
        model_name=model,
        gpu_memory_utilization=0.7,
        enable_chunked_prefill=True,
        enforce_eager=True,
        enable_prefix_caching=True,
        block_size=block_size,
        max_num_batched_tokens=max_num_batched_tokens,
        max_num_seqs=max_num_batched_tokens,
    )
    engine: LLMEngine = runner.model.llm_engine

    scheduler: Scheduler = SchedulerProxy(engine.scheduler[0])  # type: ignore
    engine.scheduler[0] = scheduler

    # SeqA
    seqA_tokens = list(range(2 * block_size))
    seqA, seq_groupA = create_dummy_prompt(
        request_id="0",
        prompt_tokens=seqA_tokens,
        max_tokens=num_output_tokens,
        block_size=block_size,
    )

    scheduler.add_seq_group(seq_groupA)

    assert seqA.data.get_num_computed_tokens() == 0

    # Prefill seqA.
    while not seqA.is_finished():
        engine.step()

    # seqB
    seqB_tokens = [t + 1 for t in seqA_tokens]  # Shift by 1.
    seqB, seq_groupB = create_dummy_prompt(
        request_id="1",
        prompt_tokens=seqB_tokens,
        max_tokens=num_output_tokens,
        block_size=block_size,
    )

    # seqC is the same as seqA.
    seqC, seq_groupC = create_dummy_prompt(
        request_id="2",
        prompt_tokens=seqA_tokens,
        max_tokens=num_output_tokens,
        block_size=block_size,
    )

    scheduler.add_seq_group(seq_groupB)
    scheduler.add_seq_group(seq_groupC)

    # Even though seqC is fully cached, it should not be prefilled since we
    # require at least 1 uncached token.
    engine.step()

    sched_metas, sched_out, _ = scheduler.last_schedule_ret()
    assert len(sched_out.scheduled_seq_groups) == 1
    assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
            seq_groupB.request_id)
    assert (sched_out.scheduled_seq_groups[0].token_chunk_size ==
            max_num_batched_tokens)

    # Once seqB is finished, seqC can be prefilled.
    while not seqB.is_finished():
        engine.step()
        sched_metas, sched_out, _ = scheduler.last_schedule_ret()
        assert len(sched_out.scheduled_seq_groups) == 1
        assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
                seq_groupB.request_id)

    engine.step()
    sched_metas, sched_out, _ = scheduler.last_schedule_ret()
    assert len(sched_out.scheduled_seq_groups) == 1
    assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
            seq_groupC.request_id)
    assert sched_out.scheduled_seq_groups[0].token_chunk_size == len(
        seqA_tokens)
@ -306,14 +306,6 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
        device = Device.GPU
        return self._allocators[device].mark_blocks_as_computed(block_ids)

    def get_computed_block_ids(self, prev_computed_block_ids: List[int],
                               block_ids: List[int],
                               skip_last_block_id: bool) -> List[int]:
        # Prefix caching only supported on GPU.
        device = Device.GPU
        return self._allocators[device].get_computed_block_ids(
            prev_computed_block_ids, block_ids, skip_last_block_id)

    def get_common_computed_block_ids(
            self, computed_seq_block_ids: List[List[int]]) -> List[int]:
        # Prefix caching only supported on GPU.
@ -342,6 +334,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
        self._swap_mapping.clear()
        return list(mapping.items())

    def find_cached_blocks_prefix(
        self,
        block_hashes: List[int],
        device: Device = Device.GPU,
    ) -> List[int]:
        return self._allocators[device].find_cached_blocks_prefix(block_hashes)


class NullBlock(Block):
    """
@ -159,12 +159,6 @@ class BlockAllocator(ABC):
    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
        pass

    @abstractmethod
    def get_computed_block_ids(self, prev_computed_block_ids: List[int],
                               block_ids: List[int],
                               skip_last_block_id: bool) -> List[int]:
        pass

    @abstractmethod
    def get_common_computed_block_ids(
            self, computed_seq_block_ids: List[List[int]]) -> List[int]:
@ -192,6 +186,13 @@ class BlockAllocator(ABC):
    class NoFreeBlocksError(ValueError):
        pass

    @abstractmethod
    def find_cached_blocks_prefix(
        self,
        block_hashes: List[int],
    ) -> List[int]:
        pass


class DeviceAwareBlockAllocator(ABC):

@ -207,9 +208,12 @@ class DeviceAwareBlockAllocator(ABC):
        pass

    @abstractmethod
    def allocate_immutable_blocks(self, prev_block: Optional[Block],
    def allocate_immutable_blocks(
        self,
        prev_block: Optional[Block],
        block_token_ids: List[List[int]],
        device: Device) -> List[Block]:
        device: Device,
    ) -> List[Block]:
        pass

    @abstractmethod
@ -246,12 +250,6 @@ class DeviceAwareBlockAllocator(ABC):
    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
        pass

    @abstractmethod
    def get_computed_block_ids(self, prev_computed_block_ids: List[int],
                               block_ids: List[int],
                               skip_last_block_id: bool) -> List[int]:
        pass

    @abstractmethod
    def get_common_computed_block_ids(
            self, computed_seq_block_ids: List[List[int]]) -> List[int]:
@ -284,3 +282,11 @@ class DeviceAwareBlockAllocator(ABC):
    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        """Prefix cache hit rate. -1 means not supported or disabled."""
        pass

    @abstractmethod
    def find_cached_blocks_prefix(
        self,
        block_hashes: List[int],
        device: Device = Device.GPU,
    ) -> List[int]:
        pass
@ -262,13 +262,6 @@ class NaiveBlockAllocator(BlockAllocator):
        """
        pass

    def get_computed_block_ids(self, prev_computed_block_ids: List[int],
                               block_ids: List[int],
                               skip_last_block_id: bool) -> List[int]:
        """No prefix caching here => return empty list.
        """
        return []

    def get_common_computed_block_ids(
            self, computed_seq_block_ids: List[List[int]]) -> List[int]:
        """Determine blocks that can be skipped in prefill.
@ -329,6 +322,10 @@ class NaiveBlockAllocator(BlockAllocator):
    def get_prefix_cache_hit_rate(self) -> float:
        return -1

    def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
        # Not applicable for the naive block allocator.
        return []


class NaiveBlock(Block):
    """An implementation of the Block class that does not support prefix
@ -1,13 +1,18 @@
"""Token blocks."""
import sys
from bisect import bisect_left
from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple
from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set,
                    Tuple)

from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
                                    get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device,
                                        DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
                                         NaiveBlockAllocator)
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
from vllm.sequence import Sequence

PrefixHash = int

@ -534,26 +539,6 @@ class PrefixCachingBlockAllocator(BlockAllocator):
        else:
            return block_id in self.evictor

    def get_computed_block_ids(self,
                               prev_computed_block_ids: List[int],
                               block_ids: List[int],
                               skip_last_block_id: bool = True) -> List[int]:
        prev_prefix_size = len(prev_computed_block_ids)
        cur_size = len(block_ids)
        if skip_last_block_id:
            cur_size -= 1

        # Sanity checks
        assert cur_size >= 0
        assert prev_prefix_size <= cur_size

        ret = prev_computed_block_ids
        for i in range(prev_prefix_size, cur_size):
            block_id = block_ids[i]
            if self.block_is_computed(block_id):
                ret.append(block_id)
        return ret

    def get_common_computed_block_ids(
            self, computed_seq_block_ids: List[List[int]]) -> List[int]:
        """Return the block ids that are common for a given sequence group.
@ -634,6 +619,47 @@ class PrefixCachingBlockAllocator(BlockAllocator):

        block.block_id = block_id  # Assign block_id

    def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
        """
        Given a list of block hashes, return the prefix of the block hashes
        that are all cached.

        Since a block's hash includes the hashes of all previous blocks, and
        since we only allocate/deallocate blocks for an entire sequence, if a
        block is cached then all previous blocks are also cached. With this
        property, we can use binary search to find the longest cached prefix.

        Args:
            block_hashes (List[int]): The list of block hashes.

        Returns:
            List[int]: The prefix of `block_hashes` that is cached.
        """

        def _block_is_cached(block_hash: PrefixHash) -> bool:
            if block_hash not in self._cached_blocks:
                return False

            cached_block_id = self._cached_blocks[block_hash]
            # We only consider blocks that are marked as computed.
            return self.block_is_computed(cached_block_id)

        def _bisect_left(a, x, key: Callable[[PrefixHash], bool]) -> int:

            # Python < 3.10 doesn't support the `key` argument of bisect_left.
            if sys.version_info < (3, 10):
                a = [key(e) for e in a]
                return bisect_left(a, x)
            else:
                return bisect_left(a, x, key=key)

        # Look for the first block that's not cached; everything before it is
        # the cached prefix.
        idx = _bisect_left(block_hashes,
                           True,
                           key=lambda x: not _block_is_cached(x))
        return block_hashes[:idx]
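A standalone sketch of the same idea, for reference: because "is cached" is monotone over the hash list (once one block is uncached, every later block is too), the boundary can be found with a binary search over the predicate. The function and variable names below are illustrative only, not part of the diff.

```python
from bisect import bisect_left


def cached_prefix(block_hashes, is_cached):
    """Return the longest prefix of block_hashes for which is_cached() holds.

    Relies on the monotone property: is_cached maps the list onto a run of
    True values followed by a run of False values.
    """
    # Map each hash to "not cached" (False sorts before True), then find the
    # first True, i.e. the first uncached block.
    flags = [not is_cached(h) for h in block_hashes]
    idx = bisect_left(flags, True)
    return block_hashes[:idx]


# Tiny example: the first two of three blocks are cached.
cache = {101, 202}
print(cached_prefix([101, 202, 303], lambda h: h in cache))  # [101, 202]
```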
class PrefixCachingBlock(Block):
    """A block implementation that supports prefix caching.
@ -843,86 +869,126 @@ class PrefixCachingBlock(Block):


class ComputedBlocksTracker:
    """Handles caching of per-sequence computed block ids.
    When a sequence appears for the first time, it traverses all of the
    blocks and detects the prefix of blocks that is computed. On the
    subsequent times, it only traverses the new blocks that were added
    and updates the already recorded prefix of blocks with the newly
    computed blocks.
    """
    Tracks the computed blocks for each sequence.

    To avoid redundant traversals, the algorithm also detects when there
    is a "gap" in the computed prefix. For example, if we have blocks =
    [1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then
    we won't try to add more computed blocks to [1,2,3] in this sequence
    iteration, and will add more computed blocks only after the sequence is
    freed and reused again.
    Internally, it maintains a map from sequence id to the list of block
    hashes for the sequence. We cache the hashes of the full blocks for each
    sequence, and make sure the hashes are calculated in the same way as the
    allocator's. When a sequence is being decoded, we also update the
    sequence's hashes accordingly and incrementally.

    Note that currently, for a given sequence, we also skip the last
    block id for caching purposes, to avoid caching of a full sequence
    From the sequence's hashes, with prefix caching enabled, we can also
    calculate the number of cached tokens for the sequence by looking up the
    number of cached block hashes in the allocator.
    """
    def __init__(self, allocator):
    def __init__(
        self,
        allocator: DeviceAwareBlockAllocator,
        block_size: int,
        enable_caching: bool,
    ):
        self._allocator = allocator
        self._cached_computed_seq_blocks: Dict[int, Tuple[List[int],
                                                          bool]] = {}
        self._block_size = block_size
        self._enable_caching = enable_caching

    def add_seq(self, seq_id: int) -> None:
        """Start tracking seq_id
        """
        assert seq_id not in self._cached_computed_seq_blocks
        self._cached_computed_seq_blocks[seq_id] = ([], False)
        # A map from seq_id to the list of block hashes for the
        # sequence. This is so that we don't have to recompute the block
        # hashes for the sequence when we need to check if it is cached.
        # Note that a block that's not full will not have its hash calculated
        # and recorded.
        self._seq_id_to_blocks_hashes: Dict[int, List[int]] = {}

        # A map from seq_id to the number of tokens that are cached for the
        # sequence.
        # We need this so that a sequence in continuous prefill doesn't
        # accidentally see its cached token count change. See comments in
        # `get_num_cached_tokens` for more details.
        self._seq_id_to_num_tokens_computed: Dict[int, int] = {}
    def _update_seq_hashes(self, seq: Sequence) -> None:
        """Incrementally update the sequence's block hashes and record them."""
        assert self._enable_caching

        block_hashes_recorded = self._seq_id_to_blocks_hashes.get(
            seq.seq_id, [])
        cur_num_blocks_recorded = len(block_hashes_recorded)
        token_ids = seq.get_token_ids()
        assert len(token_ids) >= cur_num_blocks_recorded * self._block_size, (
            f"The sequence has {len(token_ids)} tokens, but"
            f" already recorded {cur_num_blocks_recorded} blocks. "
            "This should not happen since we assume blocks are "
            "only appended, except on recomputation. When the sequence is "
            "recomputed, we should have removed the info of the old blocks.")
        # Update the computed block hashes for the sequence. Since only full
        # blocks are considered as "computed", we take the floor here.
        num_computed_blocks = len(token_ids) // self._block_size

        # We need to know the hash of the previous block to compute the hash
        # of the current block, so that blocks can be uniquely identified
        # across sequences of prefixes.
        prev_block_hash = (None if cur_num_blocks_recorded == 0 else
                           block_hashes_recorded[-1])
        # Only update the computed block hashes for the new blocks.
        for i in range(cur_num_blocks_recorded, num_computed_blocks):
            assert len(token_ids) >= (i + 1) * self._block_size
            block_token_ids = token_ids[i * self._block_size:(i + 1) *
                                        self._block_size]
            # This has to be kept in sync with the allocator's hash
            # calculation.
            block_hash = PrefixCachingBlock.hash_block_tokens(
                is_first_block=prev_block_hash is None,
                prev_block_hash=prev_block_hash,
                cur_block_token_ids=block_token_ids,
            )
            block_hashes_recorded.append(block_hash)
            prev_block_hash = block_hash

        self._seq_id_to_blocks_hashes[seq.seq_id] = block_hashes_recorded
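A minimal sketch of the chained hashing the tracker mirrors. Python's built-in `hash` over a tuple stands in for `PrefixCachingBlock.hash_block_tokens`; it only illustrates the chaining, not the allocator's actual hash function.

```python
from typing import List, Optional

BLOCK_SIZE = 4


def chained_block_hashes(token_ids: List[int]) -> List[int]:
    """Hash each *full* block, chaining in the previous block's hash so the
    same tokens appearing under a different prefix produce a different hash."""
    hashes: List[int] = []
    prev_hash: Optional[int] = None
    full_len = len(token_ids) // BLOCK_SIZE * BLOCK_SIZE
    for start in range(0, full_len, BLOCK_SIZE):
        block = tuple(token_ids[start:start + BLOCK_SIZE])
        prev_hash = hash((prev_hash, block))  # stand-in for hash_block_tokens
        hashes.append(prev_hash)
    return hashes


# 10 tokens -> only 2 full blocks are hashed; the trailing partial block is
# skipped, matching the floor division in _update_seq_hashes above.
print(len(chained_block_hashes(list(range(10)))))  # 2
```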
    def get_num_cached_tokens(self, seq: Sequence) -> int:
        if not self._enable_caching:
            return 0

        # We always try to update the sequence hashes on the fly.
        # This is to ensure that we don't miss any cached tokens for the
        # sequence during decode.
        # This routine only adds hashes for new full blocks.
        self._update_seq_hashes(seq)

        num_computed_tokens_prev = self._seq_id_to_num_tokens_computed.get(
            seq.seq_id, None)

        # TODO(rickyx): This hack could be removed once we mark blocks as
        # computed correctly with chunked prefills.
        if num_computed_tokens_prev is not None and seq.is_prefill():
            # For a sequence that is still in prefill, we don't
            # recompute the number of cached tokens.
            # This also handles chunked prefill correctly, since currently
            # we mark blocks as computed even if the sequence is still
            # partially prefilled. So a continuously prefilled sequence
            # should not see its cached token count change while running.
            return num_computed_tokens_prev

        block_hashes = self._seq_id_to_blocks_hashes[seq.seq_id]

        # This is O(logN), where N is the number of blocks.
        num_cached_blocks = len(
            self._allocator.find_cached_blocks_prefix(block_hashes))
        num_cached_tokens = num_cached_blocks * self._block_size
        self._seq_id_to_num_tokens_computed[seq.seq_id] = num_cached_tokens
        return num_cached_tokens
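A back-of-the-envelope example of the counting above (illustrative numbers only):

```python
# Illustrative only: with block_size = 16, a sequence whose first 3 block
# hashes are found in the allocator's cache reports 48 cached tokens,
# regardless of how long its (not yet hashed) partial last block is.
block_size = 16
num_cached_blocks = 3            # len(find_cached_blocks_prefix(...))
num_cached_tokens = num_cached_blocks * block_size
assert num_cached_tokens == 48
```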
    def remove_seq(self, seq_id: int) -> None:
        """Stop tracking seq_id
        """
        assert seq_id in self._cached_computed_seq_blocks
        del self._cached_computed_seq_blocks[seq_id]
        """Stop tracking the sequence."""
        if not self._enable_caching:
            return
        assert seq_id in self._seq_id_to_blocks_hashes
        del self._seq_id_to_blocks_hashes[seq_id]

    def get_cached_computed_blocks_and_update(
            self, seq_id: int, block_ids: List[int]) -> List[int]:
        """ Look at the class documentation for details
        """
        # Ensure seq_id is already tracked
        assert seq_id in self._cached_computed_seq_blocks

        # Get cached data (may be empty on the first time)
        prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[
            seq_id]

        if has_gap:
            # When gap is detected, we do not add more computed blocks at this
            # sequence iteration
            return prev_computed_block_ids

        # We do not consider the last block id for caching purposes.
        num_cur_blocks = len(block_ids) - 1
        assert num_cur_blocks >= 0

        if len(prev_computed_block_ids) >= num_cur_blocks:
            # Cache HIT
            assert len(prev_computed_block_ids) == num_cur_blocks
            return prev_computed_block_ids

        # If here, then we may possibly add more computed blocks. As a result,
        # traverse the additional blocks after prev_computed_block_ids to
        # detect more computed blocks and add them.

        # Incremental init for seq_id => Look only at the new blocks
        computed_block_ids = self._allocator.get_computed_block_ids(  # noqa: E501
            prev_computed_block_ids,
            block_ids,
            skip_last_block_id=
            True,  # We skip last block id to avoid caching of full seq
        )

        # Detect if there is a "gap"
        has_gap = len(computed_block_ids) < num_cur_blocks

        # Record
        self._cached_computed_seq_blocks[seq_id] = (computed_block_ids,
                                                    has_gap)

        return computed_block_ids
        assert seq_id in self._seq_id_to_num_tokens_computed
        del self._seq_id_to_num_tokens_computed[seq_id]


class LastAccessBlocksTracker:
@ -101,7 +101,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
        self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}

        self._computed_blocks_tracker = ComputedBlocksTracker(
            self.block_allocator)
            self.block_allocator, self.block_size, self.enable_caching)
        self._last_access_blocks_tracker = LastAccessBlocksTracker(
            self.block_allocator)

@ -170,7 +170,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
        self.block_tables[seq.seq_id] = block_table

        # Track seq
        self._computed_blocks_tracker.add_seq(seq.seq_id)
        self._last_access_blocks_tracker.add_seq(seq.seq_id)

        # Assign the block table for each sequence.
@ -178,7 +177,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
            self.block_tables[seq.seq_id] = block_table.fork()

            # Track seq
            self._computed_blocks_tracker.add_seq(seq.seq_id)
            self._last_access_blocks_tracker.add_seq(seq.seq_id)

        # Allocate cross-attention block table for encoder sequence
@ -314,11 +312,13 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
        """
        computed_seq_block_ids = []
        for seq in seqs:
            computed_seq_block_ids.append(
                self._computed_blocks_tracker.
                get_cached_computed_blocks_and_update(
                    seq.seq_id,
                    self.block_tables[seq.seq_id].physical_block_ids))
            all_blocks = self.block_tables[seq.seq_id].physical_block_ids
            num_cached_tokens = (
                self._computed_blocks_tracker.get_num_cached_tokens(seq))
            assert num_cached_tokens % self.block_size == 0
            num_cached_blocks = num_cached_tokens // self.block_size
            computed_block_ids = all_blocks[:num_cached_blocks]
            computed_seq_block_ids.append(computed_block_ids)

        # NOTE(sang): This assumes seq_block_ids doesn't contain any None.
        return self.block_allocator.get_common_computed_block_ids(
@ -332,7 +332,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
            self.block_tables[child_seq.seq_id] = src_block_table.fork()

            # Track child seq
            self._computed_blocks_tracker.add_seq(child_seq.seq_id)
            self._last_access_blocks_tracker.add_seq(child_seq.seq_id)

    def can_swap_in(self, seq_group: SequenceGroup,
@ -503,3 +502,9 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
            return AllocStatus.OK
        else:
            return AllocStatus.LATER

    def get_num_cached_tokens(self, seq: Sequence) -> int:
        """Get the number of tokens in blocks that are already computed and
        cached in the block manager for the sequence.
        """
        return self._computed_blocks_tracker.get_num_cached_tokens(seq)
@ -121,3 +121,7 @@ class BlockSpaceManager(ABC):
    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        """Prefix cache hit rate. -1 means not supported or disabled."""
        pass

    @abstractmethod
    def get_num_cached_tokens(self, seq: Sequence) -> int:
        pass
@ -89,3 +89,6 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):

    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        return -1

    def get_num_cached_tokens(self, seq: Sequence) -> int:
        return 0
@ -56,11 +56,16 @@ class SchedulingBudget:
    max_num_seqs: int
    _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
    _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
    # Number of cached tokens in the batch.
    _num_cached_tokens: int = 0
    # Number of actual non-cached tokens in the batch.
    _num_batched_tokens: int = 0
    _num_curr_seqs: int = 0

    def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
        assert num_new_tokens != 0
        # We allow num_new_tokens to be 0 when the entire sequence has
        # been cached.
        assert num_new_tokens >= 0
        assert num_new_seqs != 0
        return (self.num_batched_tokens + num_new_tokens <= self.token_budget
                and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
@ -68,12 +73,18 @@ class SchedulingBudget:
    def remaining_token_budget(self):
        return self.token_budget - self.num_batched_tokens

    def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
    def add_num_batched_tokens(self,
                               req_id: str,
                               num_batched_tokens: int,
                               num_cached_tokens: int = 0):
        if req_id in self._request_ids_num_batched_tokens:
            return
        assert num_cached_tokens >= 0
        assert num_batched_tokens >= 0

        self._request_ids_num_batched_tokens.add(req_id)
        self._num_batched_tokens += num_batched_tokens
        self._num_cached_tokens += num_cached_tokens

    def subtract_num_batched_tokens(self, req_id: str,
                                    num_batched_tokens: int):
@ -101,6 +112,10 @@ class SchedulingBudget:
    def num_curr_seqs(self):
        return self._num_curr_seqs

    @property
    def num_cached_tokens(self):
        return self._num_cached_tokens
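A small sketch of the budget accounting this change enables: only uncached tokens are charged against the token budget, while cached tokens are tracked separately so the reported batch size can still include them. The class below is a simplified stand-in written for illustration, not the vLLM dataclass itself.

```python
class TinyBudget:
    """Simplified stand-in for SchedulingBudget's token accounting."""

    def __init__(self, token_budget: int):
        self.token_budget = token_budget
        self.uncached = 0   # charged against the budget
        self.cached = 0     # scheduled "for free"

    def can_schedule(self, num_uncached: int) -> bool:
        return self.uncached + num_uncached <= self.token_budget

    def add(self, num_uncached: int, num_cached: int = 0) -> None:
        self.uncached += num_uncached
        self.cached += num_cached


budget = TinyBudget(token_budget=12)
budget.add(num_uncached=4, num_cached=4)       # an 8-token prompt, half cached
assert budget.can_schedule(4)                  # a second such prompt still fits
budget.add(num_uncached=4, num_cached=4)
assert budget.uncached + budget.cached == 16   # tokens reported as batched
```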
@dataclass
class ScheduledSequenceGroup:
@ -541,9 +556,19 @@ class Scheduler:
        assert len(self._async_stopped) == 0
        while running_queue:
            seq_group = running_queue[0]
            num_running_tokens = self._get_num_new_tokens(
                seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
            # We discard the cached-tokens info here because we don't need it
            # for a running sequence:
            # 1. If a sequence is running with chunked prefill, the cached
            #    tokens info was already used for the first prefill.
            # 2. If a sequence is running with non-chunked prefill, then it's
            #    a decoding sequence, and the cached tokens info is
            #    irrelevant.
            num_uncached_new_tokens, _ = (
                self._get_num_new_uncached_and_cached_tokens(
                    seq_group, SequenceStatus.RUNNING, enable_chunking,
                    budget))

            num_running_tokens = num_uncached_new_tokens
            if num_running_tokens == 0:
                # No budget => Stop
                break
@ -715,13 +740,15 @@ class Scheduler:
            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_new_seqs = seq_group.get_max_num_running_seqs()
            num_new_tokens = self._get_num_new_tokens(seq_group,
                                                      SequenceStatus.SWAPPED,
                                                      enable_chunking, budget)
            num_new_tokens_uncached, num_new_tokens_cached = (
                self._get_num_new_uncached_and_cached_tokens(
                    seq_group, SequenceStatus.SWAPPED, enable_chunking,
                    budget))

            if (num_new_tokens == 0
                    or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                               num_new_seqs=num_new_seqs)):
            if num_new_tokens_uncached == 0 or not budget.can_schedule(
                    num_new_tokens=num_new_tokens_uncached,
                    num_new_seqs=num_new_seqs,
            ):
                break

            if lora_int_id > 0 and curr_loras is not None:
@ -732,12 +759,19 @@ class Scheduler:
            is_prefill = seq_group.is_prefill()
            if is_prefill:
                prefill_seq_groups.append(
                    ScheduledSequenceGroup(seq_group,
                                           token_chunk_size=num_new_tokens))
                    ScheduledSequenceGroup(
                        seq_group,
                        token_chunk_size=num_new_tokens_uncached +
                        num_new_tokens_cached,
                    ))
            else:
                decode_seq_groups.append(
                    ScheduledSequenceGroup(seq_group, token_chunk_size=1))
            budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
            budget.add_num_batched_tokens(
                seq_group.request_id,
                num_batched_tokens=num_new_tokens_uncached,
                num_cached_tokens=num_new_tokens_cached,
            )
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)

        swapped_queue.extendleft(leftover_swapped)
@ -803,26 +837,30 @@ class Scheduler:
        if waiting_queue:
            seq_group = waiting_queue.popleft()
            num_new_seqs = seq_group.get_max_num_running_seqs()
            num_new_tokens = self._get_num_new_tokens(seq_group,
                                                      SequenceStatus.WAITING,
                                                      False, budget)
            num_new_tokens_uncached, _ = (
                self._get_num_new_uncached_and_cached_tokens(
                    seq_group, SequenceStatus.WAITING, False, budget))

            #Only preempt if priority inversion exists
            while running_queue and self._get_priority(
                    running_queue[-1]) > self._get_priority(seq_group):
                #Only preempt if waiting sequence cannot be allocated
                can_allocate = self.block_manager.can_allocate(seq_group)
                if (num_new_tokens and can_allocate == AllocStatus.OK
                        and budget.can_schedule(num_new_tokens=num_new_tokens,
                                                num_new_seqs=num_new_seqs)):
                if (num_new_tokens_uncached > 0
                        and can_allocate == AllocStatus.OK
                        and budget.can_schedule(
                            num_new_tokens=num_new_tokens_uncached,
                            num_new_seqs=num_new_seqs,
                        )):
                    break

                #Adjust budget to remove the victim sequence group
                vseq_group = running_queue.pop()
                num_running_tokens = self._get_num_new_tokens(
                    vseq_group, SequenceStatus.RUNNING, False, budget)
                budget.subtract_num_batched_tokens(vseq_group.request_id,
                                                   num_running_tokens)
                num_running_tokens_uncached, _ = (
                    self._get_num_new_uncached_and_cached_tokens(
                        vseq_group, SequenceStatus.RUNNING, False, budget))
                budget.subtract_num_batched_tokens(
                    vseq_group.request_id, num_running_tokens_uncached)
                num_running_seqs = vseq_group.get_max_num_running_seqs()
                budget.subtract_num_seqs(vseq_group.request_id,
                                         num_running_seqs)
@ -882,9 +920,12 @@ class Scheduler:
            assert len(waiting_seqs) == 1, (
                "Waiting sequence group should have only one prompt "
                "sequence.")
            num_new_tokens = self._get_num_new_tokens(seq_group,
                                                      SequenceStatus.WAITING,
                                                      enable_chunking, budget)
            num_new_tokens_uncached, num_new_tokens_cached = (
                self._get_num_new_uncached_and_cached_tokens(
                    seq_group, SequenceStatus.WAITING, enable_chunking,
                    budget))
            num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached

            if not enable_chunking:
                num_prompt_tokens = waiting_seqs[0].get_len()
                assert num_new_tokens == num_prompt_tokens
@ -935,10 +976,18 @@ class Scheduler:
                waiting_queue.popleft()
                continue

            if (budget.num_batched_tokens >=
                    self.scheduler_config.max_num_batched_tokens):
                # We've reached the budget limit - since there might be
                # continuous prefills in the running queue, we should break
                # to avoid scheduling any new prefills.
                break

            num_new_seqs = seq_group.get_max_num_running_seqs()
            if (num_new_tokens == 0
                    or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                               num_new_seqs=num_new_seqs)):
            if num_new_tokens_uncached == 0 or not budget.can_schedule(
                    num_new_tokens=num_new_tokens_uncached,
                    num_new_seqs=num_new_seqs,
            ):
                break

            # Can schedule this request.
@ -967,7 +1016,11 @@ class Scheduler:
            seq_groups.append(
                ScheduledSequenceGroup(seq_group=seq_group,
                                       token_chunk_size=num_new_tokens))
            budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
            budget.add_num_batched_tokens(
                seq_group.request_id,
                num_batched_tokens=num_new_tokens_uncached,
                num_cached_tokens=num_new_tokens_cached,
            )
            budget.add_num_seqs(seq_group.request_id, num_new_seqs)

        # Queue requests that couldn't be scheduled.
@ -1075,7 +1128,8 @@ class Scheduler:
        return SchedulerOutputs(
            scheduled_seq_groups=scheduled_seq_groups,
            num_prefill_groups=num_prefill_groups,
            num_batched_tokens=budget.num_batched_tokens,
            num_batched_tokens=budget.num_batched_tokens +
            budget.num_cached_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
@ -1119,7 +1173,6 @@ class Scheduler:
                running_scheduled.swapped_out) == 0:
            swapped_in = self._schedule_swapped(budget, curr_loras)

        # Schedule new prefills.
        prefills = self._schedule_prefills(budget,
                                           curr_loras,
                                           enable_chunking=True)
@ -1157,7 +1210,8 @@ class Scheduler:
            num_prefill_groups=(len(prefills.seq_groups) +
                                len(swapped_in.prefill_seq_groups) +
                                len(running_scheduled.prefill_seq_groups)),
            num_batched_tokens=budget.num_batched_tokens,
            num_batched_tokens=budget.num_batched_tokens +
            budget.num_cached_tokens,
            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
            blocks_to_copy=running_scheduled.blocks_to_copy +
@ -1584,30 +1638,140 @@ class Scheduler:

        return self.scheduler_config.num_lookahead_slots

    def _get_num_new_tokens(self, seq_group: SequenceGroup,
                            status: SequenceStatus, enable_chunking: bool,
                            budget: SchedulingBudget) -> int:
        """Get the next new tokens to compute for a given sequence group
        that's in a given `status`.
    def _get_num_new_uncached_and_cached_tokens(
        self,
        seq_group: SequenceGroup,
        status: SequenceStatus,
        enable_chunking: bool,
        budget: SchedulingBudget,
    ) -> Tuple[int, int]:
        """
        Returns the number of new uncached and cached tokens to schedule for
        a given sequence group that's in a given `status`.

        The API could chunk the number of tokens to compute based on `budget`
        if `enable_chunking` is True. If a sequence group has multiple
        sequences (e.g., running beam search), it means it is in the decoding
        phase, so chunking doesn't happen.

        Returns 0 if the new tokens cannot be computed due to the token
        budget.
        Returns (0, 0) if the new tokens cannot be computed due to the token
        budget.

        The cached tokens' blocks are already computed, and the attention
        backend will reuse the cached blocks rather than recomputing them. So
        the scheduler can schedule these cached tokens "for free".

        Args:
            seq_group: The sequence group to get the number of new tokens to
                schedule.
            status: The status of the sequences to get the number of new
                tokens to schedule.
            enable_chunking: Whether to chunk the number of tokens to compute.
            budget: The budget to chunk the number of tokens to compute.

        Returns:
            A tuple of two ints. The first int is the number of new uncached
            tokens to schedule. The second int is the number of cached tokens.
            If no more new tokens can be scheduled, returns (0, 0).
        """
        num_new_tokens = 0
        num_cached_new_tokens = 0
        num_uncached_new_tokens = 0

        seqs = seq_group.get_seqs(status=status)
        # Compute the number of new uncached and cached tokens for
        # each sequence.
        for seq in seqs:
            num_new_tokens += seq.get_num_new_tokens()
            assert num_new_tokens > 0
            if not seq.is_prefill():
                # Decode sequences should always have just 1 uncached token.
                # TODO(rickyx): Actually, is this still correct for
                # multi-step?
                num_uncached_new_tokens += 1
                continue

            num_computed_tokens_seq = seq.get_num_computed_tokens()
            all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq
            if not self.cache_config.enable_prefix_caching:
                # If prefix caching is not enabled, all new tokens are
                # uncached.
                num_uncached_new_tokens += all_num_new_tokens_seq
                continue

            # NOTE: a cached token might currently be in a block that's in the
            # evictor, meaning it's not yet allocated. However, we don't
            # exclude such tokens from the cache count because they are
            # guaranteed to be allocated later if the sequence can be
            # allocated.
            num_cached_tokens_seq = self.block_manager.get_num_cached_tokens(
                seq)

            # Sanity check.
            if num_cached_tokens_seq < num_computed_tokens_seq:
                # This should only happen with chunked prefill, and
                # the seq is still in prefill. The `num_cached_tokens_seq`
                # is the value we calculated on scheduling the first prefill.
                # For subsequent continuous prefill steps, we reuse the cached
                # token count recorded for the sequence, so the cached token
                # count can be less than the number of computed tokens.
                # See comments on `ComputedBlocksTracker` for more details.
                assert (
                    seq.is_prefill() and seq.status == SequenceStatus.RUNNING
                    and self.scheduler_config.chunked_prefill_enabled
                ), ("Number of cached tokens should not be less than the "
                    "number of computed tokens for a sequence that's still "
                    f"in prefill. But there are {num_cached_tokens_seq} cached "
                    f"tokens and {num_computed_tokens_seq} computed tokens "
                    f"for sequence {seq.seq_id}.")

            num_cached_new_tokens_seq = max(
                0, num_cached_tokens_seq - num_computed_tokens_seq)
            num_uncached_new_tokens_seq = (all_num_new_tokens_seq -
                                           num_cached_new_tokens_seq)

            num_uncached_new_tokens += num_uncached_new_tokens_seq
            num_cached_new_tokens += num_cached_new_tokens_seq
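A quick numeric illustration of the split computed in the loop above (illustrative numbers only):

```python
# Illustrative only: a waiting prompt of 40 tokens, none computed yet, whose
# first 32 tokens hit the prefix cache.
seq_len = 40
num_computed = 0
num_cached = 32                                  # from get_num_cached_tokens()
all_new = seq_len - num_computed                 # 40
cached_new = max(0, num_cached - num_computed)   # 32, scheduled "for free"
uncached_new = all_new - cached_new              # 8, charged to the budget
assert (uncached_new, cached_new) == (8, 32)
```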
        if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0:
            # For a fully cached sequence, we still need to recompute the
            # last token, so we need at least 1 uncached token to schedule.
            # See ModelRunner._compute_for_prefix_cache_hit for more details.
            num_uncached_new_tokens = 1
            num_cached_new_tokens -= 1

        if enable_chunking and len(seqs) == 1:
            # Chunk if a running request cannot fit in the given budget.
            # If the number of seqs > 1, it means it is doing beam search
            # in a decode phase. Do not chunk.
        if enable_chunking and len(seqs) == 1:
            num_uncached_new_tokens = self._chunk_new_tokens_to_schedule(
                self.scheduler_config,
                self.cache_config,
                budget,
                self._get_prompt_limit(seq_group),
                num_uncached_new_tokens,
            )

        return num_uncached_new_tokens, num_cached_new_tokens
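The fully-cached adjustment above in isolation, as a tiny sketch (illustrative only):

```python
# Illustrative only: a prompt whose tokens are all prefix-cache hits must
# still schedule one uncached token so the model produces its first output.
uncached, cached = 0, 32
if uncached == 0 and cached > 0:
    uncached, cached = 1, cached - 1
assert (uncached, cached) == (1, 31)
```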
    @staticmethod
    def _chunk_new_tokens_to_schedule(
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        budget: SchedulingBudget,
        prompt_limit: int,
        num_new_tokens: int,
    ) -> int:
        """
        Chunks the number of new tokens to schedule based on the budget when
        chunked prefill is enabled.

        Args:
            scheduler_config: The scheduler config.
            cache_config: The cache config.
            budget: The budget to chunk the number of tokens to compute.
            prompt_limit: The maximum number of tokens allowed in a prompt.
            num_new_tokens: The number of new tokens to schedule.

        Returns:
            The number of new tokens to schedule after chunking.
        """
        remaining_token_budget = budget.remaining_token_budget()
        if self.scheduler_config.is_multi_step:
        if scheduler_config.is_multi_step:
            # The current multi-step + chunked prefill capability does
            # not actually support chunking prompts.
            #
@ -1617,20 +1781,23 @@ class Scheduler:
            #
            # Prompts with more tokens than the current remaining budget
            # are postponed to future scheduler steps
            if num_new_tokens > self._get_prompt_limit(seq_group):
            if num_new_tokens > prompt_limit:
                # If the seq_group is in prompt-stage, pass the
                # num_new_tokens as-is so the caller can ignore
                # the sequence.
                pass
            else:
                num_new_tokens = 0 \
                    if num_new_tokens > remaining_token_budget \
                    else num_new_tokens
        elif self.cache_config.enable_prefix_caching:
            return num_new_tokens

            return (0 if num_new_tokens > remaining_token_budget else
                    num_new_tokens)

        if cache_config.enable_prefix_caching:
            # Adjust the remaining token budget to be divisible by the block
            # size when prefix caching is enabled.

            # When prefix caching is enabled, we always allocate
            # the number of new tokens that is divisible by the block
            # size to avoid partial block matching.
            block_size = self.cache_config.block_size
            block_size = cache_config.block_size
            remainder = budget.token_budget % block_size
            if remainder != 0:
                raise ValueError("When enabling chunked prefill and "
@ -1639,9 +1806,10 @@ class Scheduler:
                                 "block size, but got chunk_size "
                                 f"({budget.token_budget}) % block_size "
                                 f"({block_size}) = {remainder}")
            if remaining_token_budget < num_new_tokens:
                num_new_tokens = (remaining_token_budget //
                                  block_size) * block_size
            else:
            # Round down to block size.
            remaining_token_budget = (remaining_token_budget // block_size *
                                      block_size)

        num_new_tokens = min(num_new_tokens, remaining_token_budget)

        return num_new_tokens
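A small sketch of the block-size rounding performed above when prefix caching is on (illustrative numbers only):

```python
# Illustrative only: with a 16-token block size and 27 tokens of budget left,
# the budget is rounded down to a whole number of blocks before chunking, so
# a long prompt is chunked to 16 tokens rather than 27.
block_size = 16
remaining_token_budget = 27
num_new_tokens = 100
remaining_token_budget = remaining_token_budget // block_size * block_size
assert min(num_new_tokens, remaining_token_budget) == 16
```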
@ -579,6 +579,9 @@ class Sequence:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def get_num_computed_tokens(self) -> int:
        return self.data.get_num_computed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL