[Core] Multi-Step + Single Step Prefills via Chunked Prefill code path (#8378)

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Varun Sundar Rabindranath 2024-09-27 16:32:07 -04:00 committed by GitHub
parent c5d55356f9
commit c2ec430ab5
19 changed files with 513 additions and 108 deletions


@ -52,7 +52,7 @@ __global__ void advance_step_flashattn_kernel(
slot_mapping_ptr[cur_query_id] = slot_num;
}
inline void verify_tensor(std::string const& name, torch::Tensor& t,
inline void verify_tensor(std::string const& name, torch::Tensor const& t,
int64_t const size_0, int64_t const size_1,
c10::ScalarType const type) {
bool size_0_cond = true;


@ -37,6 +37,7 @@ DEFAULT_SERVER_ARGS: List[str] = [
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("is_async", [True])
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
@pytest.mark.asyncio
async def test_multi_step(
example_prompts,
@ -49,6 +50,7 @@ async def test_multi_step(
is_async: bool,
num_logprobs: Optional[int],
attention_backend: str,
enable_chunked_prefill: bool,
monkeypatch,
) -> None:
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol
@ -74,6 +76,10 @@ async def test_multi_step(
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> no logprobs
"""
if enable_chunked_prefill and \
(pp_size > 1 or attention_backend != "FLASH_ATTN"):
pytest.skip("Multi-step with Chunked-Prefill only supports"
"PP=1 and FLASH_ATTN backend")
override_backend_env_variable(monkeypatch, attention_backend)
@ -93,6 +99,9 @@ async def test_multi_step(
if eager_mode:
ms_server_args.append("--enforce-eager")
if enable_chunked_prefill:
ms_server_args.append("--enable-chunked-prefill")
distributed_args = [
"--tensor-parallel-size",
str(tp_size),


@ -16,6 +16,7 @@ NUM_PROMPTS = [10]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@ -28,6 +29,7 @@ def test_multi_step_llm(
model: str,
dtype: str,
tp_size: int,
enable_chunked_prefill: bool,
max_tokens: int,
enforce_eager: int,
num_scheduler_steps: int,
@ -51,6 +53,7 @@ def test_multi_step_llm(
model: model under test (same for single- and multi-step engines)
dtype: tensor datatype for engine to utilize
tp_size: degree of tensor-parallelism
enable_chunked_prefill: chunked-prefill on/off
max_tokens: the maximum number of tokens to generate
enforce_eager
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
@ -73,6 +76,7 @@ def test_multi_step_llm(
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
enable_chunked_prefill=enable_chunked_prefill,
num_scheduler_steps=num_scheduler_steps,
) as vllm_model:
vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens)


@ -342,9 +342,13 @@ class FlashAttentionMetadata(AttentionMetadata):
)
return self._cached_decode_metadata
def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int):
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
@ -355,6 +359,23 @@ class FlashAttentionMetadata(AttentionMetadata):
assert num_seqs > num_queries
assert self.use_cuda_graph
if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
@ -366,7 +387,6 @@ class FlashAttentionMetadata(AttentionMetadata):
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
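
For intuition, here is a minimal standalone sketch (a hypothetical _Meta stand-in, not vLLM's actual FlashAttentionMetadata) of the counter rewrite that turn_prefills_into_decodes=True performs after the first step:

from dataclasses import dataclass, field
from typing import List

@dataclass
class _Meta:
    # Simplified stand-in for the attention-metadata counters touched above.
    num_prefills: int
    num_prefill_tokens: int
    num_decode_tokens: int
    max_prefill_seq_len: int
    max_query_len: int
    slot_mapping: List[int] = field(default_factory=list)

def turn_prefills_into_decodes(meta: _Meta, num_seqs: int) -> None:
    # After the first multi-step iteration every scheduled prefill has
    # produced its first token, so it is now accounted for as a decode.
    assert meta.num_decode_tokens + meta.num_prefills == num_seqs
    meta.num_decode_tokens += meta.num_prefills
    meta.num_prefills = 0
    meta.num_prefill_tokens = 0
    meta.max_prefill_seq_len = 0
    meta.max_query_len = 1  # decodes process one token per sequence
    meta.slot_mapping = meta.slot_mapping[:num_seqs]

# Example: 2 prefills (5 and 8 tokens) and 2 decodes scheduled together.
m = _Meta(num_prefills=2, num_prefill_tokens=13, num_decode_tokens=2,
          max_prefill_seq_len=8, max_query_len=8,
          slot_mapping=list(range(15)))
turn_prefills_into_decodes(m, num_seqs=4)
assert (m.num_prefills, m.num_decode_tokens, m.max_query_len) == (0, 4, 1)
assert len(m.slot_mapping) == 4
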
@ -706,8 +726,10 @@ class FlashAttentionImpl(AttentionImpl):
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]


@ -410,18 +410,22 @@ class FlashInferMetadata(AttentionMetadata):
return self
def advance_step(
self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
):
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
assert not turn_prefills_into_decodes, \
("Chunked prefill is not supported with flashinfer yet."
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
"specific parameter.")
assert num_seqs > 0
assert num_queries > 0
assert model_input.attn_metadata is not None


@ -983,9 +983,16 @@ class SchedulerConfig:
policy: str = "fcfs") -> None:
if max_num_batched_tokens is None:
if enable_chunked_prefill:
# It is the values that have the best balance between ITL
# and TTFT on A100. Note it is not optimized for throughput.
max_num_batched_tokens = 512
if num_scheduler_steps > 1:
# Multi-step Chunked-Prefill doesn't allow prompt-chunking
# for now. Have max_num_batched_tokens set to max_model_len
# so we don't reject sequences on account of a short
# max_num_batched_tokens.
max_num_batched_tokens = max(max_model_len, 2048)
else:
# It is the values that have the best balance between ITL
# and TTFT on A100. Note it is not optimized for throughput.
max_num_batched_tokens = 512
else:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
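
The defaulting behavior above can be summarized by the following sketch (a hypothetical helper, not the actual SchedulerConfig constructor; the non-chunked-prefill branch is elided in this hunk and is assumed here to be max(max_model_len, 2048)):

def default_max_num_batched_tokens(enable_chunked_prefill: bool,
                                   num_scheduler_steps: int,
                                   max_model_len: int) -> int:
    if enable_chunked_prefill:
        if num_scheduler_steps > 1:
            # Multi-step + chunked-prefill does not chunk prompts yet, so
            # the budget must cover a full prompt of up to max_model_len.
            return max(max_model_len, 2048)
        # Chunked prefill alone: small budget that balances ITL and TTFT.
        return 512
    # No chunked prefill (assumed branch): favor throughput.
    return max(max_model_len, 2048)

assert default_max_num_batched_tokens(True, 8, 4096) == 4096
assert default_max_num_batched_tokens(True, 1, 4096) == 512
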


@ -55,9 +55,12 @@ class BlockTable:
self._num_full_slots = self._get_num_token_ids()
@staticmethod
def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
def get_num_required_blocks(token_ids: List[int],
block_size: int,
num_lookahead_slots: int = 0) -> int:
"""Calculates the minimum number of blocks required to store a given
sequence of token IDs.
sequence of token IDs along with any look-ahead slots that may be
required (like in multi-step + chunked-prefill).
This assumes worst-case scenario, where every block requires a new
allocation (e.g. ignoring prefix caching).
@ -66,12 +69,14 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be stored.
block_size (int): The maximum number of tokens that can be stored in
a single block.
num_lookahead_slots (int): look-ahead slots that the sequence may
require.
Returns:
int: The minimum number of blocks required to store the given
sequence of token IDs.
sequence of token IDs along with any required look-ahead slots.
"""
return cdiv(len(token_ids), block_size)
return cdiv(len(token_ids) + num_lookahead_slots, block_size)
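
A quick check of the new formula (cdiv re-implemented here so the snippet is self-contained; block size and token counts are illustrative):

def cdiv(a: int, b: int) -> int:
    # Ceiling division (vLLM's cdiv helper).
    return -(-a // b)

block_size = 16
# 96 prompt tokens fit exactly into 6 blocks...
assert cdiv(96, block_size) == 6
# ...but 8 lookahead slots (e.g. multi-step + chunked-prefill with
# num_scheduler_steps=8) force a 7th block to be reserved up front.
assert cdiv(96 + 8, block_size) == 7
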
def allocate(self,
token_ids: List[int],


@ -281,10 +281,15 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int:
return 0 if seq is None else seq.n_blocks
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
assert (num_lookahead_slots == 0
), "lookahead allocation not supported in BlockSpaceManagerV1"
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
self_num_required_blocks = self._get_seq_num_required_blocks(


@ -107,7 +107,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
self._last_access_blocks_tracker = LastAccessBlocksTracker(
self.block_allocator)
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
@ -117,6 +119,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
num_required_blocks = BlockTable.get_num_required_blocks(
seq.get_token_ids(),
block_size=self.block_size,
num_lookahead_slots=num_lookahead_slots,
)
if seq_group.is_encoder_decoder():


@ -21,7 +21,9 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
) -> None:
pass
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
# Always return OK for dummy purposes
return AllocStatus.OK


@ -44,7 +44,9 @@ class BlockSpaceManager(ABC):
raise ValueError(f"Unknown version {version=}")
@abstractmethod
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
pass
@abstractmethod


@ -522,7 +522,7 @@ class Scheduler:
ret.swapped_out.clear()
ret.num_lookahead_slots = self._get_num_lookahead_slots(
is_prefill=False)
is_prefill=False, enable_chunking=enable_chunking)
ret.decode_seq_groups_list.clear()
ret.prefill_seq_groups_list.clear()
@ -561,7 +561,7 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available
# slot to keep all the sequence groups in the RUNNING state.
while not self._can_append_slots(seq_group):
while not self._can_append_slots(seq_group, enable_chunking):
budget.subtract_num_batched_tokens(seq_group.request_id,
num_running_tokens)
num_running_seqs = seq_group.get_max_num_running_seqs()
@ -611,7 +611,7 @@ class Scheduler:
if not cont_loop:
break
else:
self._append_slots(seq_group, blocks_to_copy)
self._append_slots(seq_group, blocks_to_copy, enable_chunking)
is_prefill = seq_group.is_prefill()
scheduled_seq_group: ScheduledSequenceGroup = \
@ -684,7 +684,8 @@ class Scheduler:
# If the sequence group cannot be swapped in, stop.
is_prefill = seq_group.is_prefill()
alloc_status = self.block_manager.can_swap_in(
seq_group, self._get_num_lookahead_slots(is_prefill))
seq_group,
self._get_num_lookahead_slots(is_prefill, enable_chunking))
if alloc_status == AllocStatus.LATER:
break
elif alloc_status == AllocStatus.NEVER:
@ -727,7 +728,7 @@ class Scheduler:
curr_loras.add(lora_int_id)
swapped_queue.popleft()
self._swap_in(seq_group, blocks_to_swap_in)
self._append_slots(seq_group, blocks_to_copy)
self._append_slots(seq_group, blocks_to_copy, enable_chunking)
is_prefill = seq_group.is_prefill()
if is_prefill:
prefill_seq_groups.append(
@ -747,12 +748,13 @@ class Scheduler:
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_copy=blocks_to_copy,
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=False),
is_prefill=False, enable_chunking=enable_chunking),
infeasible_seq_groups=infeasible_seq_groups,
)
def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
if self.scheduler_config.chunked_prefill_enabled:
if self.scheduler_config.chunked_prefill_enabled and \
not self.scheduler_config.is_multi_step:
prompt_limit = self.scheduler_config.max_model_len
else:
prompt_limit = min(self.scheduler_config.max_model_len,
@ -899,15 +901,21 @@ class Scheduler:
waiting_queue.popleft()
continue
num_lookahead_slots: int = 0
if self.scheduler_config.is_multi_step and enable_chunking:
num_lookahead_slots = self._get_num_lookahead_slots(
True, enable_chunking)
# If the sequence group cannot be allocated, stop.
can_allocate = self.block_manager.can_allocate(seq_group)
can_allocate = self.block_manager.can_allocate(
seq_group, num_lookahead_slots=num_lookahead_slots)
if can_allocate == AllocStatus.LATER:
break
elif can_allocate == AllocStatus.NEVER:
logger.warning(
"Input prompt (%d tokens) is too long"
" and exceeds the capacity of block_manager",
num_new_tokens)
"Input prompt (%d tokens) + lookahead slots (%d) is "
"too long and exceeds the capacity of block_manager",
num_new_tokens, num_lookahead_slots)
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
@ -939,9 +947,24 @@ class Scheduler:
curr_loras.add(lora_int_id)
waiting_queue.popleft()
self._allocate_and_set_running(seq_group)
seq_group.init_multi_step(
num_scheduler_steps=self._get_num_lookahead_slots(
is_prefill=True) + 1)
if enable_chunking and self.scheduler_config.is_multi_step:
blocks_to_copy: List[Tuple[int, int]] = []
# init_multi_step_from_lookahead_slots happens in append_slots
self._append_slots(seq_group, blocks_to_copy, enable_chunking)
# This assert would trip only if a copy-on-write happened, which
# should not occur here because the sequence group's very first
# block allocation happened just above. Still, we keep the assert
# to catch any edge cases.
assert not blocks_to_copy
else:
seq_group.init_multi_step_from_lookahead_slots(
num_lookahead_slots,
num_scheduler_steps=self.scheduler_config.
num_scheduler_steps,
is_multi_step=self.scheduler_config.is_multi_step,
enable_chunking=enable_chunking)
seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group,
token_chunk_size=num_new_tokens))
@ -956,7 +979,8 @@ class Scheduler:
return SchedulerPrefillOutputs(
seq_groups=seq_groups,
ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking))
def _schedule_default(self) -> SchedulerOutputs:
"""Schedule queued requests.
@ -1153,7 +1177,8 @@ class Scheduler:
else:
return self._schedule_default()
def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
def _can_append_slots(self, seq_group: SequenceGroup,
enable_chunking: bool) -> bool:
"""Determine whether or not we have enough space in the KV cache to
continue generation of the sequence group.
"""
@ -1164,13 +1189,17 @@ class Scheduler:
self.artificial_preempt_cnt -= 1
return False
# Appending slots only occurs in decoding.
is_prefill = False
is_prefill = seq_group.is_prefill()
num_lookahead_slots = self._get_num_lookahead_slots(
is_prefill, enable_chunking)
if is_prefill and num_lookahead_slots > 0:
# Appending prefill slots only happens when multi-step and
# chunked-prefill are enabled together.
assert self.scheduler_config.is_multi_step and enable_chunking
return self.block_manager.can_append_slots(
seq_group=seq_group,
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
)
seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
no_beam_search = seq_group.sampling_params is None or (
@ -1186,7 +1215,7 @@ class Scheduler:
# such as self.running, self.swapped, and self.waiting.
scheduler_start_time = time.perf_counter()
scheduler_outputs = self._schedule()
scheduler_outputs: SchedulerOutputs = self._schedule()
now = time.time()
if not self.cache_config.enable_prefix_caching:
@ -1383,11 +1412,10 @@ class Scheduler:
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING
def _append_slots(
self,
seq_group: SequenceGroup,
blocks_to_copy: List[Tuple[int, int]],
) -> None:
def _append_slots(self,
seq_group: SequenceGroup,
blocks_to_copy: List[Tuple[int, int]],
enable_chunking: bool = False) -> None:
"""Appends new slots to the sequences in the given sequence group.
Args:
@ -1398,11 +1426,25 @@ class Scheduler:
int is the destination block index. This list is updated with
the new source and destination block indices for the appended
slots.
enable_chunking (bool): True if chunked prefill is enabled.
"""
num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)
is_prefill: bool = seq_group.is_prefill()
num_lookahead_slots: int = self._get_num_lookahead_slots(
is_prefill, enable_chunking)
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_group.init_multi_step_from_lookahead_slots(
num_lookahead_slots,
num_scheduler_steps=self.scheduler_config.num_scheduler_steps,
is_multi_step=self.scheduler_config.is_multi_step,
enable_chunking=enable_chunking)
seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING
if self.scheduler_config.is_multi_step and enable_chunking:
# In multi-step chunked-prefill any sequence type can have
# slots appended.
seq_status = None
for seq in seq_group.get_seqs(status=seq_status):
cows = self.block_manager.append_slots(seq, num_lookahead_slots)
if len(cows) > 0:
blocks_to_copy.extend(cows)
@ -1513,16 +1555,32 @@ class Scheduler:
passed_delay = True
return passed_delay
def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
def _get_num_lookahead_slots(self, is_prefill: bool,
enable_chunking: bool) -> int:
"""The number of slots to allocate per sequence per step, beyond known
token ids. Speculative decoding uses these slots to store KV activations
of tokens which may or may not be accepted.
Speculative decoding does not yet support prefill, so we do not perform
lookahead allocation for prefill.
When chunking is enabled with multi-step, we also allocate lookahead
slots for prefills, since those prefills turn into decodes after the
first step.
"""
if is_prefill:
return 0
if self.scheduler_config.is_multi_step and enable_chunking:
# num_lookahead_slots was introduced in the context of decodes,
# in Speculative Decoding.
# When num_scheduler_steps is, say, 8, num_lookahead_slots is 7:
# one decode step is always performed and we wish to do 7 more.
#
# "Lookaheads" for prefills are introduced to support
# Chunked-Prefill in Multi-Step.
return self.scheduler_config.num_lookahead_slots + 1
else:
return 0
return self.scheduler_config.num_lookahead_slots
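
The resulting policy, in one standalone sketch (a hypothetical mirror of the method above; scheduler_config.num_lookahead_slots is num_scheduler_steps - 1 for multi-step, as the comment notes):

def get_num_lookahead_slots(is_prefill: bool, enable_chunking: bool,
                            is_multi_step: bool,
                            num_lookahead_slots: int) -> int:
    if is_prefill:
        if is_multi_step and enable_chunking:
            # Prefills turn into decodes after the first step, so they need
            # a slot for the first step plus every remaining lookahead step.
            return num_lookahead_slots + 1
        return 0
    return num_lookahead_slots

# With num_scheduler_steps = 8, num_lookahead_slots = 7:
assert get_num_lookahead_slots(True, True, True, 7) == 8   # chunked prefill
assert get_num_lookahead_slots(True, False, True, 7) == 0  # plain prefill
assert get_num_lookahead_slots(False, True, True, 7) == 7  # decode
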
@ -1565,6 +1623,16 @@ class Scheduler:
if remaining_token_budget < num_new_tokens:
num_new_tokens = (remaining_token_budget //
block_size) * block_size
elif self.scheduler_config.is_multi_step:
if num_new_tokens > self._get_prompt_limit(seq_group):
# If the seq_group is in prompt-stage, pass the
# num_new_tokens as-is so the caller can ignore
# the sequence.
pass
else:
num_new_tokens = 0 \
if num_new_tokens > remaining_token_budget \
else num_new_tokens
else:
num_new_tokens = min(num_new_tokens, remaining_token_budget)
return num_new_tokens
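
A sketch of the all-or-nothing budgeting rule added above for multi-step prompts (hypothetical helper; names are illustrative):

def multi_step_prompt_budget(num_new_tokens: int,
                             remaining_token_budget: int,
                             prompt_limit: int) -> int:
    """Multi-step never splits a prompt across scheduling rounds: either
    the whole prompt fits in the remaining budget or none of it is
    scheduled this round."""
    if num_new_tokens > prompt_limit:
        # Pass through unchanged so the caller can reject the sequence
        # for exceeding the prompt limit.
        return num_new_tokens
    return 0 if num_new_tokens > remaining_token_budget else num_new_tokens

assert multi_step_prompt_budget(400, 512, 4096) == 400    # whole prompt fits
assert multi_step_prompt_budget(600, 512, 4096) == 0      # wait for next round
assert multi_step_prompt_budget(5000, 512, 4096) == 5000  # exceeds prompt limit
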


@ -980,9 +980,13 @@ class EngineArgs:
if speculative_config is not None:
raise ValueError("Speculative decoding is not supported with "
"multi-step (--num-scheduler-steps > 1)")
if self.enable_chunked_prefill:
raise ValueError("Chunked prefill is not supported with "
"multi-step (--num-scheduler-steps > 1)")
if self.enable_chunked_prefill and self.enable_prefix_caching:
raise ValueError("Multi-Step is not supported with "
"both Chunked-Prefill and Prefix-Caching "
"enabled together.")
if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
raise ValueError("Multi-Step Chunked-Prefill is not supported "
"for pipeline-parallel-size > 1")
# make sure num_lookahead_slots is set the higher value depending on
# if we are using speculative decoding or multi-step
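
Putting the constraints above together, a configuration accepted by this commit looks roughly like the sketch below (offline LLM entry point; the model name and memory settings are illustrative, and the combination requires prefix caching off, pipeline_parallel_size == 1, and a FLASH_ATTN-capable GPU):

from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",     # illustrative small model
    num_scheduler_steps=8,           # multi-step scheduling
    enable_chunked_prefill=True,     # prefills go through the chunked path
    enable_prefix_caching=False,     # incompatible with this combination
    use_v2_block_manager=True,
    gpu_memory_utilization=0.7,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
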


@ -363,11 +363,18 @@ class _AsyncLLMEngine(LLMEngine):
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()
# is_first_step_output is True only when num_steps is 1 for all
# the sequences. When num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=allow_async_output_proc,
is_last_step=True)
is_last_step=True,
is_first_step_output=is_first_step_output)
if outputs and allow_async_output_proc:
assert len(


@ -90,6 +90,12 @@ class OutputData(NamedTuple):
scheduler_outputs: SchedulerOutputs
is_async: bool
is_last_step: bool
# Indicates if this output is from the first step of the
# multi-step. When multi-step is disabled, this is always
# set to True.
# is_first_step_output is invalid when `outputs` has
# outputs from multiple steps.
is_first_step_output: Optional[bool]
skip: List[int]
@ -108,13 +114,15 @@ class SchedulerContext:
def append_output(self, outputs: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduler_outputs: SchedulerOutputs, is_async: bool,
is_last_step: bool):
is_last_step: bool,
is_first_step_output: Optional[bool]):
self.output_queue.append(
OutputData(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=is_async,
is_last_step=is_last_step,
is_first_step_output=is_first_step_output,
skip=[]))
@ -237,9 +245,10 @@ class LLMEngine:
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
"enable_prefix_caching=%s, use_async_output_proc=%s, "
"use_cached_outputs=%s, mm_processor_kwargs=%s)",
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"mm_processor_kwargs=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
@ -270,6 +279,7 @@ class LLMEngine:
model_config.served_model_name,
scheduler_config.use_v2_block_manager,
scheduler_config.num_scheduler_steps,
scheduler_config.chunked_prefill_enabled,
scheduler_config.multi_step_stream_outputs,
cache_config.enable_prefix_caching,
model_config.use_async_output_proc,
@ -957,8 +967,66 @@ class LLMEngine:
ctx: The virtual engine context to work on
request_id: If provided, then only this request is going to be processed
"""
def update_prefill_num_computed_tokens(
seq_group: SequenceGroup,
seq_group_meta: SequenceGroupMetadata, num_outputs: int,
is_first_step_output: Optional[bool]) -> None:
"""
When multi-step and chunked-prefill are enabled together, the
prefill sequences scheduled for multi-step execution turn into
decodes in the first step itself. This function accounts
for that conversion.
seq_group: SequenceGroup - A prefill seq_group
seq_group_meta: SequenceGroupMetadata - Metadata of the given
prefill seq_group
num_outputs: int - number of step outputs being processed for the
given seq_group
is_first_step_output: Optional[bool] -
If multi-step is enabled and num_outputs is 1, this value
indicates if this output belongs to the first step of the
multi-step.
If multi-step is enabled and num_outputs > 1, this value
must be None, as num_outputs > 1 indicates that outputs from
all the steps in multi-step are submitted in a single burst.
When multi-step is disabled, this value is always True.
"""
assert seq_group_meta.is_prompt
token_chunk_size = seq_group_meta.token_chunk_size
if num_outputs == 1:
assert is_first_step_output is not None
if seq_group_meta.state.num_steps == 1:
assert is_first_step_output is True
seq_group.update_num_computed_tokens(token_chunk_size)
return
# multi-step prefill is only supported when multi-step is
# enabled with chunked prefill
assert self.scheduler_config.is_multi_step and \
self.scheduler_config.chunked_prefill_enabled
if is_first_step_output is True:
# This sequence is a prompt during the first step only.
seq_group.update_num_computed_tokens(token_chunk_size)
return
assert is_first_step_output is None
# multi-step prefill is only supported when multi-step is
# enabled with chunked prefill. Outputs from all the steps are
# submitted in a single burst.
assert self.scheduler_config.is_multi_step and \
self.scheduler_config.chunked_prefill_enabled
assert num_outputs == seq_group_meta.state.num_steps, \
f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}" #noqa
# This sequence is a prompt during the first step only.
seq_group.update_num_computed_tokens(token_chunk_size)
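
The branching in the helper above reduces to the following decision sketch (hypothetical function; it only answers whether the prompt's token_chunk_size gets credited by this call):

from typing import Optional

def credits_chunk_now(num_outputs: int, num_steps: int,
                      is_first_step_output: Optional[bool]) -> bool:
    if num_outputs == 1:
        if num_steps == 1:
            return True                       # plain single-step engine
        return is_first_step_output is True   # only the first of N steps
    # Outputs from all steps arrive in one burst: credit the chunk once.
    assert is_first_step_output is None and num_outputs == num_steps
    return True

assert credits_chunk_now(1, 1, True)        # single-step
assert credits_chunk_now(1, 8, True)        # first step of a multi-step run
assert not credits_chunk_now(1, 8, False)   # later steps: already credited
assert credits_chunk_now(8, 8, None)        # burst submission
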
now = time.time()
if len(ctx.output_queue) == 0:
@ -969,20 +1037,27 @@ class LLMEngine:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, skip) = ctx.output_queue[0]
is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
else:
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, skip) = ctx.output_queue.popleft()
is_last_step, is_first_step_output,
skip) = ctx.output_queue.popleft()
# Sanity check
assert len(seq_group_metadata_list) == len(
scheduler_outputs.scheduled_seq_groups)
# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
if len(outputs) > 1:
has_multiple_outputs: bool = len(outputs) > 1
if has_multiple_outputs:
assert self.scheduler_config.is_multi_step or \
self.speculative_config
# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
outputs_by_sequence_group = create_output_by_sequence_group(
outputs, num_seq_groups=len(seq_group_metadata_list))
# We have outputs for multiple steps submitted in a single burst,
# so invalidate is_first_step_output.
is_first_step_output = None
else:
outputs_by_sequence_group = outputs
@ -1018,14 +1093,17 @@ class LLMEngine:
finished_before.append(i)
continue
if len(outputs) > 1:
if has_multiple_outputs:
output = outputs_by_sequence_group[i]
else:
output = [outputs_by_sequence_group[0][i]]
if not is_async:
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
if not is_async and seq_group_meta.is_prompt:
# Updates for all decodes happen when we actually append the
# token ids to the seq in process_outputs.
update_prefill_num_computed_tokens(seq_group, seq_group_meta,
len(output),
is_first_step_output)
if outputs:
for o in outputs:
@ -1159,8 +1237,18 @@ class LLMEngine:
if seq_group.is_finished():
continue
seq_group.update_num_computed_tokens(
seq_group_metadata.token_chunk_size)
if seq_group_metadata.is_prompt:
if self.scheduler_config.is_multi_step and \
self.scheduler_config.chunked_prefill_enabled:
# Prompts are scheduled in multi-step only when
# chunking is enabled. These prompts turn into
# decodes after the very first step. Therefore,
# we skip the update to the num_computed_tokens
# here.
pass
else:
seq_group.update_num_computed_tokens(
seq_group_metadata.token_chunk_size)
if seq_group_metadata.do_sample:
assert len(sequence_group_outputs.samples) == 1, (
@ -1172,6 +1260,7 @@ class LLMEngine:
assert len(seq_group.seqs) == 1
seq = seq_group.seqs[0]
seq.append_token_id(sample.output_token, sample.logprobs)
seq_group.update_num_computed_tokens(1)
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
@ -1324,12 +1413,19 @@ class LLMEngine:
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[0] = SchedulerOutputState()
# is_first_step_output is True only when num_steps is 1 for all
# the sequences. When num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
# Add results to the output_queue
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=allow_async_output_proc,
is_last_step=True)
is_last_step=True,
is_first_step_output=is_first_step_output)
if outputs and allow_async_output_proc:
assert len(outputs) == 1, (


@ -170,6 +170,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
token_id=output_token_id,
logprobs=output_logprob,
)
seq.data.update_num_computed_tokens(1)
self._process_decode_and_stop(seq, sampling_params)


@ -743,10 +743,35 @@ class SequenceGroup:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
if self.prompt_adapter_request else 0
def init_multi_step(self, num_scheduler_steps: int) -> None:
self.state.num_steps = num_scheduler_steps
def init_multi_step(self, num_steps: int) -> None:
self.state.num_steps = num_steps
self.state.current_step = 0
def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
num_scheduler_steps: int,
is_multi_step: bool,
enable_chunking: bool) -> None:
if not is_multi_step:
self.init_multi_step(num_steps=num_scheduler_steps)
return
# Multi-Step case
is_prefill = self.is_prefill()
# The asserts below reflect the expectations of the current system.
if is_prefill and enable_chunking:
assert num_lookahead_slots == num_scheduler_steps
self.init_multi_step(num_steps=num_lookahead_slots)
else:
is_decode: bool = not is_prefill
# If it is a prefill, num_lookahead_slots must be 0
assert num_lookahead_slots == 0 or is_decode
# If it is a decode, num_lookahead_slots + 1 must match
# the scheduler steps.
assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
self.init_multi_step(num_steps=num_lookahead_slots + 1)
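
For concreteness, the num_steps that ends up in seq_group.state can be traced with this simplified mirror (a sketch; it assumes num_lookahead_slots is passed as produced by the scheduler's _get_num_lookahead_slots above):

def resolve_num_steps(is_prefill: bool, enable_chunking: bool,
                      is_multi_step: bool, num_lookahead_slots: int,
                      num_scheduler_steps: int) -> int:
    if not is_multi_step:
        return num_scheduler_steps
    if is_prefill and enable_chunking:
        # The chunked prefill runs for the full multi-step window; it
        # becomes a decode after the first step.
        assert num_lookahead_slots == num_scheduler_steps
        return num_lookahead_slots
    # Decodes: one guaranteed step plus the lookahead steps.
    return num_lookahead_slots + 1

# num_scheduler_steps = 8 (7 lookahead slots for decodes, 8 for
# chunked prefills):
assert resolve_num_steps(True, True, True, 8, 8) == 8     # chunked prefill
assert resolve_num_steps(False, True, True, 7, 8) == 8    # decode
assert resolve_num_steps(False, False, False, 0, 1) == 1  # single-step engine
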
def get_last_latency(self, now: float) -> Optional[float]:
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
@ -1010,6 +1035,20 @@ class SequenceGroupMetadata(
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
if self.prompt_adapter_request else 0
# Multi-Step Chunked-Prefill property
@property
def is_single_step_prompt(self) -> bool:
# do_sample is True only when the token_chunk_size matches the
# num_uncomputed_tokens of the sequence. This indicates that
# the prompt will finish processing in a single `execute_model`
# step.
return self.is_prompt and self.do_sample
def get_first_seq_id(self) -> int:
# This is an efficient way of fetching the seq_id when
# we know this SequenceGroup has only one sequence.
return next(iter(self.seq_data))
def apply_delta(self,
sequence_group_metadata_delta: SequenceGroupMetadataDelta):
for id, delta in sequence_group_metadata_delta.seq_data_delta.items():
@ -1022,7 +1061,8 @@ class SequenceGroupMetadata(
def finish_step(self) -> None:
assert self.state is not None
assert self.state.current_step < self.state.num_steps
assert self.state.current_step < self.state.num_steps, \
f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa
self.state.current_step += 1


@ -14,7 +14,7 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
get_pythonized_sample_results)
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.utils import PyObjectCache
from vllm.utils import PyObjectCache, async_tensor_h2d
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
@ -30,6 +30,14 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["flash-attn"]
def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
-> List[str]:
if chunked_prefill_enabled:
return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
else:
return MULTI_STEP_ATTENTION_BACKENDS
def seq_output_builder():
@ -144,11 +152,13 @@ class StatefulModelInput(BroadcastableModelInput):
is_multi_step: bool = True
is_last_step: bool = False
is_first_multi_step: bool = False
base_output_proc_callback: Optional[Callable] = None
# ping-pong data structures for multi-step to wait on the previous step
step_cuda_events: List[torch.cuda.Event] = field(
default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2)
num_seqs: int = -1
num_queries: int = -1
num_single_step_prefills: int = 0
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
assert self.frozen_model_input is not None
@ -161,6 +171,7 @@ class StatefulModelInput(BroadcastableModelInput):
'is_first_multi_step': self.is_first_multi_step,
'num_seqs': self.num_seqs,
'num_queries': self.num_queries,
'num_single_step_prefills': self.num_single_step_prefills,
}
tensor_dict.update(new_tensor_dict)
return tensor_dict
@ -209,6 +220,81 @@ class StatefulModelInput(BroadcastableModelInput):
sampled_token_ids=sampled_token_ids,
pythonized=False))
def maybe_advance_sampling_metadata(self, device: str, pin_memory: bool):
"""
sampling_metadata.selected_token_indices is constructed for the
first-step in Multi-Step. However, when chunked-prefill is enabled with
multi-step, the scheduled prompts are fully processed in the
first-step and are processed as decodes in the rest of the steps.
This function updates the sampling_metadata.selected_token_indices
to account for this conversion.
Example:
Let 2 prompts and 2 decodes be scheduled together. Let the
num-tokens to process for the 2 prompts be 5 and 8 respectively.
In that case, sampling_metadata.selected_token_indices will be
[4, 12, 13, 14], as it is constructed for the first step in
multi-step.
However, the prompts turn into decodes after the first step,
and the num-tokens for the previously-prompt sequences will
be 1 and 1 as they are decodes now. selected_token_indices
must therefore be updated to [0, 1, 2, 3].
"""
assert self.current_step == 1 and self.num_single_step_prefills > 0
if not get_pp_group().is_last_rank:
return
assert self.frozen_model_input is not None
assert self.frozen_model_input.sampling_metadata is not None
self.frozen_model_input.sampling_metadata.selected_token_indices = \
async_tensor_h2d(list(range(self.num_queries)),
dtype=torch.long,
target_device=device,
pin_memory=pin_memory)
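
Continuing the docstring's example (two prompts of 5 and 8 tokens plus two decodes), a self-contained sketch of the index rewrite, using torch directly rather than vLLM's async_tensor_h2d helper:

import torch

# Step-1 layout: tokens are flattened as [p0_0..p0_4, p1_0..p1_7, d0, d1].
query_lens_step1 = [5, 8, 1, 1]
selected = torch.cumsum(torch.tensor(query_lens_step1), dim=0) - 1
assert selected.tolist() == [4, 12, 13, 14]  # last token of each sequence

# From step 2 onward every sequence is a decode contributing one token,
# so the selected indices collapse to consecutive positions.
num_queries = len(query_lens_step1)
selected_after = torch.arange(num_queries, dtype=torch.long)
assert selected_after.tolist() == [0, 1, 2, 3]
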
def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool):
"""
Advancing the data structures of StatefulModelInput::frozen_model_input
is only required when prefills are scheduled together with decodes to
run in multi-step. This advancement/correction accounts for the
conversion of prefills into decodes after the first step of multi-step.
"""
if self.current_step != 1 or self.num_single_step_prefills == 0:
return
assert self.frozen_model_input is not None
fmi = self.frozen_model_input
# Truncate input_tokens
assert fmi.input_tokens is not None
assert fmi.input_tokens.shape[0] >= self.num_seqs
fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs]
# Update frozen_model_input::input_positions.
assert fmi.input_positions is not None
assert fmi.input_positions.shape[0] >= self.num_seqs
fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self.
num_seqs]
# Assert that features unsupported in this path are not in use
assert fmi.lora_mapping is None
assert fmi.lora_requests is not None
assert len(fmi.lora_requests) == 0
assert fmi.attn_metadata is not None
assert fmi.prompt_adapter_mapping is None
assert fmi.prompt_adapter_requests is not None
assert len(fmi.prompt_adapter_requests) == 0
assert fmi.multi_modal_kwargs is not None
assert len(fmi.multi_modal_kwargs) == 0
self.frozen_model_input = dataclasses.replace(
self.frozen_model_input,
input_tokens=fmi_new_input_tokens,
input_positions=fmi_new_input_positions)
self.maybe_advance_sampling_metadata(device, pin_memory)
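
Shape-wise, the truncation above amounts to the following (assumed sizes reuse the 2-prompt/2-decode example; the advance_step kernel subsequently fills the kept rows with the newly sampled tokens and positions):

import torch

num_seqs = 4  # 2 former prefills + 2 decodes
input_tokens = torch.zeros(15, dtype=torch.long)     # step-1 flattened layout
input_positions = torch.zeros(15, dtype=torch.long)

# After step 1 each sequence contributes exactly one token, so only the
# leading num_seqs rows of the frozen buffers are kept.
input_tokens = input_tokens[:num_seqs]
input_positions = input_positions[:num_seqs]
assert input_tokens.shape[0] == input_positions.shape[0] == num_seqs
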
# MutableModelInputForGPUWithMultiStepMetadata is not a subclass of
# ModelInputForGPU but it wraps the actual input dataclass and adds multi-step
@ -220,6 +306,19 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
super().__init__(*args, **kwargs)
# Check attention backend support.
supported_attention_backends: List[str] = \
_get_supported_attention_backends(
self.scheduler_config.chunked_prefill_enabled)
if self.attn_backend.get_name() not in supported_attention_backends:
ms_config_str: str = "Multi-Step + Chunked-Prefill" \
if self.scheduler_config.chunked_prefill_enabled \
else "Multi-Step"
raise ValueError(
f"{ms_config_str} not supported for attention backend: "
f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND "
f"to a value from {supported_attention_backends}.")
# uses the base model runner to execute the model and wraps it with
# multi-step logic
self._base_model_runner: GPUModelRunnerBase = base_model_runner
@ -248,14 +347,25 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> StatefulModelInput:
frozen_model_input = self._base_model_runner.prepare_model_input(
seq_group_metadata_list, virtual_engine, finished_requests_ids)
frozen_model_input: ModelInputForGPUWithSamplingMetadata = \
self._base_model_runner.prepare_model_input(
seq_group_metadata_list,
virtual_engine,
finished_requests_ids)
assert frozen_model_input.query_lens is not None
assert frozen_model_input.seq_lens is not None
assert frozen_model_input.attn_metadata is not None
num_queries = len(frozen_model_input.query_lens)
num_seqs = len(frozen_model_input.seq_lens)
num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills
model_input = StatefulModelInput(
frozen_model_input=frozen_model_input,
num_seqs=len(frozen_model_input.seq_lens),
num_queries=len(frozen_model_input.query_lens),
)
num_seqs=num_seqs,
num_queries=num_queries,
num_single_step_prefills=num_single_step_prefills)
return model_input
def _async_process_outputs(self, model_input: StatefulModelInput,
@ -265,7 +375,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
output_proc_callback()
cont = True
for model_output in model_input.cached_outputs:
for step_num, model_output in enumerate(model_input.cached_outputs):
if not model_output.pythonized:
model_output.maybe_pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
@ -276,7 +386,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
seq_group_metadata_list=ctx.seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False)
is_last_step=False,
is_first_step_output=step_num == 0)
output_proc_callback()
else:
@ -292,9 +403,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
has_async_callback = output_proc_callback is not None
outputs = []
for output_id in range(len(model_input.cached_outputs)):
output = model_input.cached_outputs[output_id]
is_last_step = output_id == len(model_input.cached_outputs) - 1
for step_num, output in enumerate(model_input.cached_outputs):
is_last_step = step_num == len(model_input.cached_outputs) - 1
# For non-async case:
# -- We simply add the outputs
@ -323,7 +433,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False)
is_last_step=False,
is_first_step_output=step_num == 0)
else:
outputs.append(output.sampler_output)
else:
@ -389,18 +500,27 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input = self._advance_step(
model_input, model_input.cached_outputs[-1].sampler_output)
output_proc_callback = None
# frozen_model_input may have been updated
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
if model_input.base_output_proc_callback is None:
assert frozen_model_input is not None
model_input.base_output_proc_callback = \
frozen_model_input.async_callback
if frozen_model_input.async_callback is not None:
output_proc_callback = frozen_model_input.async_callback
assert output_proc_callback is not None
assert model_input.base_output_proc_callback is not None
async_callback = functools.partial(
self._async_process_outputs,
model_input=model_input,
output_proc_callback=output_proc_callback)
output_proc_callback=model_input.base_output_proc_callback)
frozen_model_input = dataclasses.replace( # type: ignore
model_input.frozen_model_input = dataclasses.replace( # type: ignore
model_input.frozen_model_input,
async_callback=async_callback)
# Update the local instance
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
# Execute the model
@ -455,8 +575,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# Pythonize the output and block if needed since it is the last step
if model_input.is_last_step:
outputs = self._final_process_outputs(model_input,
output_proc_callback)
outputs = self._final_process_outputs(
model_input, model_input.base_output_proc_callback)
self.pythonization_cache.reset()
return outputs
@ -484,11 +604,14 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
def _advance_step(self, model_input: StatefulModelInput,
out: SamplerOutput) -> StatefulModelInput:
if self.attn_backend.get_name() not in MULTI_STEP_ATTENTION_BACKENDS:
raise ValueError(
f"Multi-step not supported for attention backend: "
f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND "
f"to a value from {MULTI_STEP_ATTENTION_BACKENDS}.")
model_input.maybe_advance_frozen_model_input(self.device,
self.pin_memory)
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
assert frozen_model_input.input_tokens is not None
assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs
assert frozen_model_input.attn_metadata is not None
sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
num_seqs = model_input.num_seqs
@ -498,13 +621,15 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
attn_metadata = frozen_model_input.attn_metadata
assert attn_metadata is not None
turn_prefills_into_decodes: bool = model_input.current_step == 1 and \
model_input.num_single_step_prefills != 0
attn_metadata.advance_step(
frozen_model_input,
sampled_token_ids,
self.block_size,
num_seqs,
num_queries,
)
turn_prefills_into_decodes=turn_prefills_into_decodes)
return model_input


@ -76,8 +76,9 @@ class MultiStepWorker(Worker):
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
assert frozen_model_input.attn_metadata is not None
# clear the cached decode metadata so that it can be recomputed on
# the workers
# clear the cached metadata so that it can be recomputed on
# the workers.
frozen_model_input.attn_metadata._cached_prefill_metadata = None
frozen_model_input.attn_metadata._cached_decode_metadata = None
model_input.is_first_multi_step = is_first_multi_step