[Performance] Enable chunked prefill and prefix caching together (#7753)
parent f508e03e7f
commit e3580537a4
@@ -6,6 +6,7 @@ prefill requests are chunked.
 
 Run `pytest tests/models/test_chunked_prefill.py`.
 """
+from contextlib import nullcontext
 
 import pytest
 
@@ -156,3 +157,68 @@ def test_models_with_fp8_kv_cache(
         name_0="no_chunked_prefill",
         name_1="chunked_prefill",
     )
+
+
+@pytest.mark.parametrize("max_tokens", [16])
+@pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("chunk_size", [30, 32])
+@pytest.mark.parametrize("use_v2_block_manager", [False, True])
+# NOTE: Increasing this in this suite will fail CI because we currently cannot
+# reset the distributed env properly. Use a value > 1 only when testing locally.
+@pytest.mark.parametrize("tensor_parallel_size", [1])
+def test_with_prefix_caching(
+    vllm_runner,
+    max_tokens: int,
+    enforce_eager: bool,
+    chunk_size: int,
+    use_v2_block_manager: bool,
+    tensor_parallel_size: int,
+) -> None:
+    """
+    Checks exact match decode with and without prefix caching
+    with chunked prefill enabled.
+    """
+    model = "meta-llama/Llama-2-7b-chat-hf"
+    # The common prompt has 142 tokens with the Llama-2 tokenizer.
+    common_prompt = "You are a helpful AI assistant " * 20
+    unique_prompts = [
+        "Question",  # Warmup
+        "Question",  # Fully cached
+        "Another question",  # Partially cached
+    ]
+    full_prompts = [f"{common_prompt}\n{p}" for p in unique_prompts]
+
+    max_num_batched_tokens = max_num_seqs = chunk_size
+    outputs = {}  # type: ignore
+    check_result = True
+    for enable in (True, False):
+        with vllm_runner(
+                model,
+                dtype="half",
+                max_num_batched_tokens=max_num_batched_tokens,
+                enable_chunked_prefill=True,
+                enable_prefix_caching=enable,
+                tensor_parallel_size=tensor_parallel_size,
+                use_v2_block_manager=use_v2_block_manager,
+                enforce_eager=enforce_eager,
+                max_num_seqs=max_num_seqs,
+        ) as vllm_model:
+            # It should fail when prefix caching is enabled and the chunk
+            # size is not a multiple of the block size (16).
+            should_fail = chunk_size % 16 != 0 and enable
+            check_result &= not should_fail
+            outputs[enable] = []
+            # Send the requests one by one to ensure the cache is populated.
+            with pytest.raises(ValueError) if should_fail else nullcontext():
+                for prompt in full_prompts:
+                    outputs[enable] += vllm_model.generate_greedy([prompt],
+                                                                  max_tokens)
+
+    # Check results only if we did not expect a failure.
+    if check_result:
+        check_outputs_equal(
+            outputs_0_lst=outputs[False],
+            outputs_1_lst=outputs[True],
+            name_0="w/o prefix caching",
+            name_1="with prefix caching",
+        )
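For context (not part of the diff): the test above exercises the combination through the test runner; a minimal sketch of how a user would enable both features through the public `LLM` entry point is shown below. The model name, chunk size, and prompt are placeholders, not recommendations.

from vllm import LLM, SamplingParams

# Chunked prefill and prefix caching enabled together. With prefix caching on,
# max_num_batched_tokens (the chunk size) must be divisible by the block size.
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    enable_chunked_prefill=True,
    enable_prefix_caching=True,
    max_num_batched_tokens=2048,
)
sampling = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["You are a helpful AI assistant. Question: ..."], sampling)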
@@ -595,3 +595,43 @@ def test_sliding_window_multi_seq():
 
     # assert all blocks are free now
     assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
+
+
+def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill():
+    """When prefix cache and chunked prefill are enabled, the block manager
+    should only mark a chunk of blocks as computed instead of all blocks.
+    """
+
+    block_size = 4
+    num_cpu_blocks = 0
+    num_gpu_blocks = 16
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_gpu_blocks,
+                                        num_cpu_blocks,
+                                        watermark=0,
+                                        enable_caching=True)
+
+    # Set the prompt size to have num_gpu_blocks - 1 full blocks.
+    prompt_length = block_size * num_gpu_blocks - 1
+
+    # Allocate (reserve) all blocks.
+    _, seq_group = create_dummy_prompt("0",
+                                       prompt_length,
+                                       block_size=block_size)
+    block_manager.allocate(seq_group)
+    assert seq_group.seqs[0].n_blocks == num_gpu_blocks
+
+    # 1st chunk: compute 2.5 blocks. Should mark 2 blocks as computed.
+    token_chunk_size = int(block_size * 2.5)
+    block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
+    computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
+    assert len(computed_blocks) == 2
+
+    # Record the actually computed tokens.
+    seq_group.seqs[0].data.update_num_computed_tokens(token_chunk_size)
+
+    # 2nd chunk: complete the 3rd block and 4 additional blocks.
+    token_chunk_size = int(block_size * 4.5)
+    block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
+    computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
+    assert len(computed_blocks) == 7
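A plain-Python sketch (no vLLM imports, not part of the diff) of the block accounting the test above checks: full blocks are derived from the already-computed tokens plus the current chunk, so partially filled blocks are never marked as computed.

block_size = 4
computed_tokens = 0
for chunk in (int(block_size * 2.5), int(block_size * 4.5)):  # 10 tokens, then 18 tokens
    full_blocks = (computed_tokens + chunk) // block_size
    print(full_blocks)  # 2 after the first chunk, 7 after the second
    computed_tokens += chunk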
@@ -562,3 +562,42 @@ def test_chunked_prefill_max_seqs():
     assert len(get_sequence_groups(out)) == max_seqs
     assert not running[0].is_prefill()
     assert not running[1].is_prefill()
+
+
+def test_prefix_caching():
+    """Verify allocating full blocks when prefix caching is enabled."""
+    block_size = 4
+    max_seqs = 10
+    max_model_len = 80
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True)
+    cache_config = CacheConfig(block_size,
+                               1.0,
+                               1,
+                               "auto",
+                               enable_prefix_caching=True)
+    cache_config.num_cpu_blocks = 0
+    cache_config.num_gpu_blocks = 32
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to the scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           block_size=block_size,
+                                           prompt_length=50)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert seq_group_meta[0].token_chunk_size == 50
+    # Verify the second request is chunked. Although the remaining budget is
+    # 64 - 50 = 14, we only allocate full blocks for prefix caching, so only
+    # 4 * (14 // 4) = 12 tokens are allocated.
+    assert seq_group_meta[1].token_chunk_size == 12
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 62
@@ -681,14 +681,20 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         for block in block_table:
             block.last_accessed = access_time
 
-    def compute_full_blocks_in_seq(self, seq: Sequence):
+    def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
         if seq.seq_id not in self.block_tables:
             return
-        max_full_block = seq.get_len() // self.block_size - 1
+
+        # When chunked prefill is enabled, the computed full blocks
+        # should be calculated based on the number of computed tokens.
+        max_computed_tokens = (seq.data.get_num_computed_tokens() +
+                               token_chunk_size)
+        computed_full_blocks = max_computed_tokens // self.block_size
+
         block_table = self.block_tables[seq.seq_id]
-        if max_full_block == -1:
+        if computed_full_blocks == 0:
             return
-        for i in reversed(range(max_full_block)):
+        for i in reversed(range(computed_full_blocks)):
             if block_table[i].computed:
                 break
             block_table[i].computed = True
@@ -718,10 +724,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
         return commonprefix([ids for ids in ids_list if ids != []])
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         if self.enable_caching:
             for seq in seq_group.get_seqs():
-                self.compute_full_blocks_in_seq(seq)
+                self.compute_full_blocks_in_seq(seq, token_chunk_size)
 
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         if device == Device.GPU:
@@ -290,7 +290,8 @@ class BlockSpaceManagerV2(BlockSpaceManager):
             self._last_access_blocks_tracker.update_last_access(
                 seq.seq_id, now)
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         # If prefix caching is enabled, mark immutable blocks as computed
         # right after they have been scheduled (for prefill). This assumes
         # the scheduler is synchronous so blocks are actually computed when
@@ -80,7 +80,8 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
                                       seq_group: List[Sequence]) -> List[int]:
         return []
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         pass
 
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
@@ -115,7 +115,8 @@ class BlockSpaceManager(ABC):
         pass
 
     @abstractmethod
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         pass
 
     @abstractmethod
@@ -1226,7 +1226,8 @@ class Scheduler:
         # will crash the vLLM instance / will not retry.
         for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
             self.block_manager.mark_blocks_as_computed(
-                scheduled_seq_group.seq_group)
+                scheduled_seq_group.seq_group,
+                scheduled_seq_group.token_chunk_size)
 
         self._seq_group_metadata_cache[self.next_cache_id].reset()
 
@@ -1457,10 +1458,27 @@
         for seq in seqs:
             num_new_tokens += seq.get_num_new_tokens()
         assert num_new_tokens > 0
-        # Chunk if a running request cannot fit in.
-        # If number of seq > 1, it means it is doing beam search in a
-        # decode phase. Do not chunk in that case.
+        # Chunk if a running request cannot fit in the given budget.
+        # If the number of seqs > 1, it means the request is doing beam
+        # search in a decode phase. Do not chunk.
         if enable_chunking and len(seqs) == 1:
-            num_new_tokens = min(num_new_tokens,
-                                 budget.remaining_token_budget())
+            remaining_token_budget = budget.remaining_token_budget()
+            if self.cache_config.enable_prefix_caching:
+                # When prefix caching is enabled, we always allocate
+                # a number of new tokens that is divisible by the block
+                # size to avoid partial block matching.
+                block_size = self.cache_config.block_size
+                remainder = budget.token_budget % block_size
+                if remainder != 0:
+                    raise ValueError("When enabling chunked prefill and "
+                                     "prefix caching, max_num_batched_tokens "
+                                     "(chunk size) must be divisible by "
+                                     "block size, but got chunk_size "
+                                     f"({budget.token_budget}) % block_size "
+                                     f"({block_size}) = {remainder}")
+                if remaining_token_budget < num_new_tokens:
+                    num_new_tokens = (remaining_token_budget //
+                                      block_size) * block_size
+            else:
+                num_new_tokens = min(num_new_tokens, remaining_token_budget)
         return num_new_tokens
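A standalone sketch (not part of the diff) of the rounding rule added above, with assumed numbers mirroring the scheduler test: block size 4, a remaining budget of 14 tokens, and 30 tokens still to prefill. Only whole blocks are scheduled, so the chunk shrinks to 12 tokens.

block_size = 4
remaining_token_budget = 14
num_new_tokens = 30
if remaining_token_budget < num_new_tokens:
    # Round the chunk down to a multiple of the block size so a partially
    # filled block is never exposed to the prefix cache.
    num_new_tokens = (remaining_token_budget // block_size) * block_size
print(num_new_tokens)  # 12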
@@ -501,23 +501,48 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                             and self.sliding_window is None
                             and inter_data.is_prompt)
         inter_data.prefix_cache_hit = prefix_cache_hit
-        if self.chunked_prefill_enabled and prefix_cache_hit:
-            raise RuntimeError(
-                "chunked prefill cannot be used with prefix caching now.")
 
-        # If prefix cache is hit, advance context length to bypass
-        # hit blocks. Accordingly, input tokens, position and query length
-        # have to be updated.
-        if prefix_cache_hit:
-            assert computed_block_nums is not None
-            context_len = len(computed_block_nums) * self.block_size
-            inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
-                seq_idx][context_len:]
-            inter_data.input_positions[seq_idx] = inter_data.input_positions[
-                seq_idx][context_len:]
-            inter_data.context_lens[seq_idx] = context_len
-            inter_data.query_lens[
-                seq_idx] = inter_data.seq_lens[seq_idx] - context_len
+        if not prefix_cache_hit:
+            return
+
+        assert computed_block_nums is not None
+        # The cache hit prompt tokens in this sequence. Note that
+        # this may be larger than the sequence length if chunked
+        # prefill is enabled.
+        prefix_cache_len = len(computed_block_nums) * self.block_size
+        # The number of so far computed prompt tokens in this sequence.
+        context_len = inter_data.context_lens[seq_idx]
+        # The total number of prompt tokens in this sequence.
+        # When chunked prefill is enabled, this is the token number of
+        # computed chunks + current chunk.
+        seq_len = inter_data.seq_lens[seq_idx]
+        if prefix_cache_len <= context_len:
+            # We already passed the cache hit region,
+            # so do normal computation.
+            pass
+        elif context_len < prefix_cache_len < seq_len:
+            # Partial hit. Compute the missing part.
+            uncomputed_start = prefix_cache_len - context_len
+            inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
+                seq_idx][uncomputed_start:]
+            inter_data.input_positions[seq_idx] = inter_data.input_positions[
+                seq_idx][uncomputed_start:]
+            context_len = prefix_cache_len
+
+            inter_data.context_lens[seq_idx] = context_len
+            inter_data.query_lens[
+                seq_idx] = inter_data.seq_lens[seq_idx] - context_len
+        elif seq_len <= prefix_cache_len:
+            # Full hit. Only compute the last token to avoid
+            # erroneous behavior. FIXME: Ideally we should directly
+            # mark all tokens as computed in the scheduler and do not
+            # schedule this sequence, so this case should not happen.
+            inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
+                seq_idx][-1:]
+            inter_data.input_positions[seq_idx] = inter_data.input_positions[
+                seq_idx][-1:]
+            inter_data.query_lens[seq_idx] = 1
+            inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
 
     def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
                                     seq_idx: int,
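A toy walk-through (not part of the diff, assumed lengths) of the three cases handled above: `context_len` is what earlier chunks already computed, `prefix_cache_len` is what the prefix cache holds, and `seq_len` is the computed chunks plus the current chunk.

block_size = 16
prefix_cache_len = 3 * block_size  # 48 cached tokens
context_len = 32                   # tokens computed by earlier chunks
seq_len = 80                       # computed chunks + current chunk

if prefix_cache_len <= context_len:
    # Cache hit region already consumed: compute the chunk as usual.
    query_len = seq_len - context_len
elif context_len < prefix_cache_len < seq_len:
    # Partial hit: skip the cached tokens of this chunk, compute the rest.
    context_len = prefix_cache_len
    query_len = seq_len - context_len
else:
    # Full hit: recompute only the last token.
    context_len = seq_len - 1
    query_len = 1
print(context_len, query_len)  # 48 32 for these numbers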