[Core] Multi-Step + Single Step Prefills via Chunked Prefill code path (#8378)

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>

parent: c5d55356f9
commit: c2ec430ab5
@ -52,7 +52,7 @@ __global__ void advance_step_flashattn_kernel(
slot_mapping_ptr[cur_query_id] = slot_num;
}

inline void verify_tensor(std::string const& name, torch::Tensor& t,
inline void verify_tensor(std::string const& name, torch::Tensor const& t,
int64_t const size_0, int64_t const size_1,
c10::ScalarType const type) {
bool size_0_cond = true;
@ -37,6 +37,7 @@ DEFAULT_SERVER_ARGS: List[str] = [
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("is_async", [True])
|
||||
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
|
||||
@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_step(
|
||||
example_prompts,
|
||||
@ -49,6 +50,7 @@ async def test_multi_step(
|
||||
is_async: bool,
|
||||
num_logprobs: Optional[int],
|
||||
attention_backend: str,
|
||||
enable_chunked_prefill: bool,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol
|
||||
@ -74,6 +76,10 @@ async def test_multi_step(
|
||||
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
|
||||
completions endpoint; `None` -> no logprobs
|
||||
"""
|
||||
if enable_chunked_prefill and \
|
||||
(pp_size > 1 or attention_backend != "FLASH_ATTN"):
|
||||
pytest.skip("Multi-step with Chunked-Prefill only supports"
|
||||
"PP=1 and FLASH_ATTN backend")
|
||||
|
||||
override_backend_env_variable(monkeypatch, attention_backend)
|
||||
|
||||
@ -93,6 +99,9 @@ async def test_multi_step(
|
||||
if eager_mode:
|
||||
ms_server_args.append("--enforce-eager")
|
||||
|
||||
if enable_chunked_prefill:
|
||||
ms_server_args.append("--enable-chunked-prefill")
|
||||
|
||||
distributed_args = [
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
|
@ -16,6 +16,7 @@ NUM_PROMPTS = [10]
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("tp_size", [1])
|
||||
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.mark.parametrize("enforce_eager", [True])
|
||||
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
|
||||
@ -28,6 +29,7 @@ def test_multi_step_llm(
|
||||
model: str,
|
||||
dtype: str,
|
||||
tp_size: int,
|
||||
enable_chunked_prefill: bool,
|
||||
max_tokens: int,
|
||||
enforce_eager: int,
|
||||
num_scheduler_steps: int,
|
||||
@ -51,6 +53,7 @@ def test_multi_step_llm(
|
||||
model: model under test (same for single- and multi-step engines)
|
||||
dtype: tensor datatype for engine to utilize
|
||||
tp_size: degree of tensor-parallelism
|
||||
enable_chunked_prefill: chunked-prefill on/off
|
||||
max_tokens: the maximum number of tokens to generate
|
||||
enforce_eager: whether the engine enforces eager-mode (no CUDA graph) execution
|
||||
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
|
||||
@ -73,6 +76,7 @@ def test_multi_step_llm(
|
||||
gpu_memory_utilization=0.7,
|
||||
tensor_parallel_size=tp_size,
|
||||
use_v2_block_manager=True,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
) as vllm_model:
|
||||
vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens)
|
||||
|
@ -342,9 +342,13 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
def advance_step(self,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
sampled_token_ids: Optional[torch.Tensor],
|
||||
block_size: int, num_seqs: int, num_queries: int):
|
||||
block_size: int,
|
||||
num_seqs: int,
|
||||
num_queries: int,
|
||||
turn_prefills_into_decodes: bool = False):
|
||||
"""
|
||||
Update metadata in-place to advance one decode step.
|
||||
"""
|
||||
@ -355,6 +359,23 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
assert num_seqs > num_queries
|
||||
assert self.use_cuda_graph
|
||||
|
||||
if turn_prefills_into_decodes:
|
||||
# When Multi-Step is enabled with Chunked-Prefill, prefills and
|
||||
# decodes are scheduled together. In the first step, all the
|
||||
# prefills turn into decodes. This update reflects that
|
||||
# conversion.
|
||||
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
||||
self.num_decode_tokens += self.num_prefills
|
||||
self.num_prefills = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.max_prefill_seq_len = 0
|
||||
self.max_query_len = 1
|
||||
|
||||
self.slot_mapping = self.slot_mapping[:num_seqs]
|
||||
else:
|
||||
assert self.seq_lens is not None
|
||||
assert self.max_decode_seq_len == max(self.seq_lens)
|
||||
|
||||
assert self.num_prefills == 0
|
||||
assert self.num_prefill_tokens == 0
|
||||
assert self.num_decode_tokens == num_seqs
|
||||
@ -366,7 +387,6 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
assert self.seq_lens_tensor.shape == (num_seqs, )
|
||||
assert self.max_query_len == 1
|
||||
assert self.max_prefill_seq_len == 0
|
||||
assert self.max_decode_seq_len == max(self.seq_lens)
|
||||
|
||||
assert self.query_start_loc is not None
|
||||
assert self.query_start_loc.shape == (num_queries + 1, )
|
||||
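The turn_prefills_into_decodes branch added to FlashAttentionMetadata.advance_step above only shuffles counters. A minimal standalone sketch of that bookkeeping (toy dataclass, not vLLM's class; field names mirror the hunk):

# Sketch: how the counters change when scheduled prefills become decodes
# after the first multi-step step (assumed simplification of the hunk above).
from dataclasses import dataclass
from typing import List

@dataclass
class ToyAttnMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decode_tokens: int
    max_prefill_seq_len: int
    max_query_len: int
    slot_mapping: List[int]

    def turn_prefills_into_decodes(self, num_seqs: int) -> None:
        # Every former prefill now produces exactly one query token per step.
        assert self.num_decode_tokens + self.num_prefills == num_seqs
        self.num_decode_tokens += self.num_prefills
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.max_prefill_seq_len = 0
        self.max_query_len = 1
        self.slot_mapping = self.slot_mapping[:num_seqs]

# 2 prefills (5 and 8 tokens) plus 2 decodes scheduled together.
meta = ToyAttnMetadata(num_prefills=2, num_prefill_tokens=13,
                       num_decode_tokens=2, max_prefill_seq_len=8,
                       max_query_len=8, slot_mapping=list(range(15)))
meta.turn_prefills_into_decodes(num_seqs=4)
assert meta.num_decode_tokens == 4 and meta.max_query_len == 1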
@ -706,8 +726,10 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
||||
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
||||
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
|
||||
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
|
@ -410,18 +410,22 @@ class FlashInferMetadata(AttentionMetadata):
|
||||
|
||||
return self
|
||||
|
||||
def advance_step(
|
||||
self,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
sampled_token_ids: Optional[torch.Tensor],
|
||||
block_size: int,
|
||||
num_seqs: int,
|
||||
num_queries: int,
|
||||
):
|
||||
def advance_step(self,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
sampled_token_ids: Optional[torch.Tensor],
|
||||
block_size: int,
|
||||
num_seqs: int,
|
||||
num_queries: int,
|
||||
turn_prefills_into_decodes: bool = False):
|
||||
"""
|
||||
Update metadata in-place to advance one decode step.
|
||||
"""
|
||||
|
||||
assert not turn_prefills_into_decodes, \
|
||||
("Chunked prefill is not supported with flashinfer yet."
|
||||
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
|
||||
"specific parameter.")
|
||||
|
||||
assert num_seqs > 0
|
||||
assert num_queries > 0
|
||||
assert model_input.attn_metadata is not None
|
||||
|
@ -983,9 +983,16 @@ class SchedulerConfig:
policy: str = "fcfs") -> None:
if max_num_batched_tokens is None:
if enable_chunked_prefill:
# These are the values that give the best balance between ITL
# and TTFT on A100. Note they are not optimized for throughput.
max_num_batched_tokens = 512
if num_scheduler_steps > 1:
# Multi-step Chunked-Prefill doesn't allow prompt-chunking
# for now. Set max_num_batched_tokens to max_model_len
# so we don't reject sequences on account of a short
# max_num_batched_tokens.
max_num_batched_tokens = max(max_model_len, 2048)
else:
# These are the values that give the best balance between ITL
# and TTFT on A100. Note they are not optimized for throughput.
max_num_batched_tokens = 512
else:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
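The default-budget selection above can be summarized as follows. A small sketch (hypothetical helper, not a vLLM API; the non-chunked branch is assumed to be max(max_model_len, 2048), per the trailing comment):

def default_max_num_batched_tokens(enable_chunked_prefill: bool,
                                   num_scheduler_steps: int,
                                   max_model_len: int) -> int:
    if enable_chunked_prefill:
        if num_scheduler_steps > 1:
            # No prompt chunking under multi-step: admit whole prompts.
            return max(max_model_len, 2048)
        # Chunked prefill alone: small budget balancing ITL and TTFT.
        return 512
    # Chunked prefill disabled (assumption): throughput-friendly floor.
    return max(max_model_len, 2048)

assert default_max_num_batched_tokens(True, 8, 4096) == 4096
assert default_max_num_batched_tokens(True, 1, 4096) == 512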
@ -55,9 +55,12 @@ class BlockTable:
self._num_full_slots = self._get_num_token_ids()

@staticmethod
def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
def get_num_required_blocks(token_ids: List[int],
block_size: int,
num_lookahead_slots: int = 0) -> int:
"""Calculates the minimum number of blocks required to store a given
sequence of token IDs.
sequence of token IDs along with any look-ahead slots that may be
required (like in multi-step + chunked-prefill).

This assumes worst-case scenario, where every block requires a new
allocation (e.g. ignoring prefix caching).
@ -66,12 +69,14 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be stored.
block_size (int): The maximum number of tokens that can be stored in
a single block.
num_lookahead_slots (int): look-ahead slots that the sequence may
require.

Returns:
int: The minimum number of blocks required to store the given
sequence of token IDs.
sequence of token IDs along with any required look-ahead slots.
"""
return cdiv(len(token_ids), block_size)
return cdiv(len(token_ids) + num_lookahead_slots, block_size)

def allocate(self,
token_ids: List[int],
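A worked example of the updated formula, using a standalone re-implementation of cdiv:

from math import ceil

def get_num_required_blocks(num_token_ids: int, block_size: int,
                            num_lookahead_slots: int = 0) -> int:
    # cdiv(len(token_ids) + num_lookahead_slots, block_size)
    return ceil((num_token_ids + num_lookahead_slots) / block_size)

# A 30-token prompt with block size 16 needs 2 blocks without lookahead,
# but 3 blocks once 8 lookahead slots (multi-step) are reserved.
assert get_num_required_blocks(30, 16) == 2
assert get_num_required_blocks(30, 16, num_lookahead_slots=8) == 3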
@ -281,10 +281,15 @@ class BlockSpaceManagerV1(BlockSpaceManager):
|
||||
def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int:
|
||||
return 0 if seq is None else seq.n_blocks
|
||||
|
||||
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
|
||||
def can_allocate(self,
|
||||
seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int = 0) -> AllocStatus:
|
||||
# FIXME(woosuk): Here we assume that all sequences in the group share
|
||||
# the same prompt. This may not be true for preempted sequences.
|
||||
|
||||
assert (num_lookahead_slots == 0
|
||||
), "lookahead allocation not supported in BlockSpaceManagerV1"
|
||||
|
||||
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
|
||||
|
||||
self_num_required_blocks = self._get_seq_num_required_blocks(
|
||||
|
@ -107,7 +107,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
|
||||
self._last_access_blocks_tracker = LastAccessBlocksTracker(
|
||||
self.block_allocator)
|
||||
|
||||
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
|
||||
def can_allocate(self,
|
||||
seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int = 0) -> AllocStatus:
|
||||
# FIXME(woosuk): Here we assume that all sequences in the group share
|
||||
# the same prompt. This may not be true for preempted sequences.
|
||||
|
||||
@ -117,6 +119,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
|
||||
num_required_blocks = BlockTable.get_num_required_blocks(
|
||||
seq.get_token_ids(),
|
||||
block_size=self.block_size,
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
)
|
||||
|
||||
if seq_group.is_encoder_decoder():
|
||||
|
@ -21,7 +21,9 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
|
||||
def can_allocate(self,
|
||||
seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int = 0) -> AllocStatus:
|
||||
# Always return OK for dummy purposes
|
||||
return AllocStatus.OK
|
||||
|
||||
|
@ -44,7 +44,9 @@ class BlockSpaceManager(ABC):
|
||||
raise ValueError(f"Unknown version {version=}")
|
||||
|
||||
@abstractmethod
|
||||
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
|
||||
def can_allocate(self,
|
||||
seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int = 0) -> AllocStatus:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
@ -522,7 +522,7 @@ class Scheduler:
|
||||
ret.swapped_out.clear()
|
||||
|
||||
ret.num_lookahead_slots = self._get_num_lookahead_slots(
|
||||
is_prefill=False)
|
||||
is_prefill=False, enable_chunking=enable_chunking)
|
||||
|
||||
ret.decode_seq_groups_list.clear()
|
||||
ret.prefill_seq_groups_list.clear()
|
||||
@ -561,7 +561,7 @@ class Scheduler:
|
||||
|
||||
# NOTE(woosuk): Preemption happens only when there is no available
|
||||
# slot to keep all the sequence groups in the RUNNING state.
|
||||
while not self._can_append_slots(seq_group):
|
||||
while not self._can_append_slots(seq_group, enable_chunking):
|
||||
budget.subtract_num_batched_tokens(seq_group.request_id,
|
||||
num_running_tokens)
|
||||
num_running_seqs = seq_group.get_max_num_running_seqs()
|
||||
@ -611,7 +611,7 @@ class Scheduler:
|
||||
if not cont_loop:
|
||||
break
|
||||
else:
|
||||
self._append_slots(seq_group, blocks_to_copy)
|
||||
self._append_slots(seq_group, blocks_to_copy, enable_chunking)
|
||||
is_prefill = seq_group.is_prefill()
|
||||
|
||||
scheduled_seq_group: ScheduledSequenceGroup = \
|
||||
@ -684,7 +684,8 @@ class Scheduler:
|
||||
# If the sequence group cannot be swapped in, stop.
|
||||
is_prefill = seq_group.is_prefill()
|
||||
alloc_status = self.block_manager.can_swap_in(
|
||||
seq_group, self._get_num_lookahead_slots(is_prefill))
|
||||
seq_group,
|
||||
self._get_num_lookahead_slots(is_prefill, enable_chunking))
|
||||
if alloc_status == AllocStatus.LATER:
|
||||
break
|
||||
elif alloc_status == AllocStatus.NEVER:
|
||||
@ -727,7 +728,7 @@ class Scheduler:
|
||||
curr_loras.add(lora_int_id)
|
||||
swapped_queue.popleft()
|
||||
self._swap_in(seq_group, blocks_to_swap_in)
|
||||
self._append_slots(seq_group, blocks_to_copy)
|
||||
self._append_slots(seq_group, blocks_to_copy, enable_chunking)
|
||||
is_prefill = seq_group.is_prefill()
|
||||
if is_prefill:
|
||||
prefill_seq_groups.append(
|
||||
@ -747,12 +748,13 @@ class Scheduler:
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
num_lookahead_slots=self._get_num_lookahead_slots(
|
||||
is_prefill=False),
|
||||
is_prefill=False, enable_chunking=enable_chunking),
|
||||
infeasible_seq_groups=infeasible_seq_groups,
|
||||
)
|
||||
|
||||
def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
|
||||
if self.scheduler_config.chunked_prefill_enabled:
|
||||
if self.scheduler_config.chunked_prefill_enabled and \
|
||||
not self.scheduler_config.is_multi_step:
|
||||
prompt_limit = self.scheduler_config.max_model_len
|
||||
else:
|
||||
prompt_limit = min(self.scheduler_config.max_model_len,
|
||||
@ -899,15 +901,21 @@ class Scheduler:
|
||||
waiting_queue.popleft()
|
||||
continue
|
||||
|
||||
num_lookahead_slots: int = 0
|
||||
if self.scheduler_config.is_multi_step and enable_chunking:
|
||||
num_lookahead_slots = self._get_num_lookahead_slots(
|
||||
True, enable_chunking)
|
||||
|
||||
# If the sequence group cannot be allocated, stop.
|
||||
can_allocate = self.block_manager.can_allocate(seq_group)
|
||||
can_allocate = self.block_manager.can_allocate(
|
||||
seq_group, num_lookahead_slots=num_lookahead_slots)
|
||||
if can_allocate == AllocStatus.LATER:
|
||||
break
|
||||
elif can_allocate == AllocStatus.NEVER:
|
||||
logger.warning(
|
||||
"Input prompt (%d tokens) is too long"
|
||||
" and exceeds the capacity of block_manager",
|
||||
num_new_tokens)
|
||||
"Input prompt (%d tokens) + lookahead slots (%d) is "
|
||||
"too long and exceeds the capacity of block_manager",
|
||||
num_new_tokens, num_lookahead_slots)
|
||||
for seq in waiting_seqs:
|
||||
seq.status = SequenceStatus.FINISHED_IGNORED
|
||||
ignored_seq_groups.append(seq_group)
|
||||
@ -939,9 +947,24 @@ class Scheduler:
|
||||
curr_loras.add(lora_int_id)
|
||||
waiting_queue.popleft()
|
||||
self._allocate_and_set_running(seq_group)
|
||||
seq_group.init_multi_step(
|
||||
num_scheduler_steps=self._get_num_lookahead_slots(
|
||||
is_prefill=True) + 1)
|
||||
|
||||
if enable_chunking and self.scheduler_config.is_multi_step:
|
||||
blocks_to_copy: List[Tuple[int, int]] = []
|
||||
# init_multi_step_from_lookahead_slots happens in append_slots
|
||||
self._append_slots(seq_group, blocks_to_copy, enable_chunking)
|
||||
# This assert will trip when a copy-on-write happens. This is
|
||||
# not a concern as the very first sequence-group block
|
||||
# allocation happens above. Still, we have the assert to
|
||||
# catch any edge-cases.
|
||||
assert not blocks_to_copy
|
||||
else:
|
||||
seq_group.init_multi_step_from_lookahead_slots(
|
||||
num_lookahead_slots,
|
||||
num_scheduler_steps=self.scheduler_config.
|
||||
num_scheduler_steps,
|
||||
is_multi_step=self.scheduler_config.is_multi_step,
|
||||
enable_chunking=enable_chunking)
|
||||
|
||||
seq_groups.append(
|
||||
ScheduledSequenceGroup(seq_group=seq_group,
|
||||
token_chunk_size=num_new_tokens))
|
||||
@ -956,7 +979,8 @@ class Scheduler:
|
||||
return SchedulerPrefillOutputs(
|
||||
seq_groups=seq_groups,
|
||||
ignored_seq_groups=ignored_seq_groups,
|
||||
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
|
||||
num_lookahead_slots=self._get_num_lookahead_slots(
|
||||
is_prefill=True, enable_chunking=enable_chunking))
|
||||
|
||||
def _schedule_default(self) -> SchedulerOutputs:
|
||||
"""Schedule queued requests.
|
||||
@ -1153,7 +1177,8 @@ class Scheduler:
|
||||
else:
|
||||
return self._schedule_default()
|
||||
|
||||
def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
|
||||
def _can_append_slots(self, seq_group: SequenceGroup,
|
||||
enable_chunking: bool) -> bool:
|
||||
"""Determine whether or not we have enough space in the KV cache to
|
||||
continue generation of the sequence group.
|
||||
"""
|
||||
@ -1164,13 +1189,17 @@ class Scheduler:
|
||||
self.artificial_preempt_cnt -= 1
|
||||
return False
|
||||
|
||||
# Appending slots only occurs in decoding.
|
||||
is_prefill = False
|
||||
is_prefill = seq_group.is_prefill()
|
||||
num_lookahead_slots = self._get_num_lookahead_slots(
|
||||
is_prefill, enable_chunking)
|
||||
|
||||
if is_prefill and num_lookahead_slots > 0:
|
||||
# Appending prefill slots only happens when multi-step and
|
||||
# chunked-prefill are enabled together.
|
||||
assert self.scheduler_config.is_multi_step and enable_chunking
|
||||
|
||||
return self.block_manager.can_append_slots(
|
||||
seq_group=seq_group,
|
||||
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
|
||||
)
|
||||
seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
|
||||
|
||||
def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
|
||||
no_beam_search = seq_group.sampling_params is None or (
|
||||
@ -1186,7 +1215,7 @@ class Scheduler:
|
||||
# such as self.running, self.swapped, and self.waiting.
|
||||
scheduler_start_time = time.perf_counter()
|
||||
|
||||
scheduler_outputs = self._schedule()
|
||||
scheduler_outputs: SchedulerOutputs = self._schedule()
|
||||
now = time.time()
|
||||
|
||||
if not self.cache_config.enable_prefix_caching:
|
||||
@ -1383,11 +1412,10 @@ class Scheduler:
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
def _append_slots(
|
||||
self,
|
||||
seq_group: SequenceGroup,
|
||||
blocks_to_copy: List[Tuple[int, int]],
|
||||
) -> None:
|
||||
def _append_slots(self,
|
||||
seq_group: SequenceGroup,
|
||||
blocks_to_copy: List[Tuple[int, int]],
|
||||
enable_chunking: bool = False) -> None:
|
||||
"""Appends new slots to the sequences in the given sequence group.
|
||||
|
||||
Args:
|
||||
@ -1398,11 +1426,25 @@ class Scheduler:
|
||||
int is the destination block index. This list is updated with
|
||||
the new source and destination block indices for the appended
|
||||
slots.
|
||||
enable_chunking (bool): True if chunked prefill is enabled.
|
||||
"""
|
||||
num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
|
||||
seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)
|
||||
is_prefill: bool = seq_group.is_prefill()
|
||||
num_lookahead_slots: int = self._get_num_lookahead_slots(
|
||||
is_prefill, enable_chunking)
|
||||
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
seq_group.init_multi_step_from_lookahead_slots(
|
||||
num_lookahead_slots,
|
||||
num_scheduler_steps=self.scheduler_config.num_scheduler_steps,
|
||||
is_multi_step=self.scheduler_config.is_multi_step,
|
||||
enable_chunking=enable_chunking)
|
||||
|
||||
seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING
|
||||
if self.scheduler_config.is_multi_step and enable_chunking:
|
||||
# In multi-step chunked-prefill any sequence type can have
|
||||
# slots appended.
|
||||
seq_status = None
|
||||
|
||||
for seq in seq_group.get_seqs(status=seq_status):
|
||||
cows = self.block_manager.append_slots(seq, num_lookahead_slots)
|
||||
if len(cows) > 0:
|
||||
blocks_to_copy.extend(cows)
|
||||
@ -1513,16 +1555,32 @@ class Scheduler:
|
||||
passed_delay = True
|
||||
return passed_delay
|
||||
|
||||
def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
|
||||
def _get_num_lookahead_slots(self, is_prefill: bool,
|
||||
enable_chunking: bool) -> int:
|
||||
"""The number of slots to allocate per sequence per step, beyond known
|
||||
token ids. Speculative decoding uses these slots to store KV activations
|
||||
of tokens which may or may not be accepted.
|
||||
|
||||
Speculative decoding does not yet support prefill, so we do not perform
|
||||
lookahead allocation for prefill.
|
||||
|
||||
When chunking is enabled with multi-step, we allocate lookahead slots
|
||||
for the prefills for when the prefills turn into decodes in the first
|
||||
step.
|
||||
"""
|
||||
if is_prefill:
|
||||
return 0
|
||||
if self.scheduler_config.is_multi_step and enable_chunking:
|
||||
# num_lookahead_slots was introduced in the context of decodes,
|
||||
# in Speculative Decoding.
|
||||
# When the num_scheduler_steps is 8, say, then the
|
||||
# num_lookahead_slots is 7. Meaning, we are doing a 1-step of
|
||||
# decode anyways and we wish to do 7 more.
|
||||
#
|
||||
# "lookaheads" for prefills, is introduced in support for
|
||||
# Chunked-Prefill in Multi-Step.
|
||||
return self.scheduler_config.num_lookahead_slots + 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
return self.scheduler_config.num_lookahead_slots
|
||||
|
||||
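Condensed, standalone version of the lookahead-slot rule above (helper name is illustrative, not the scheduler method itself):

def get_num_lookahead_slots(is_prefill: bool, enable_chunking: bool,
                            is_multi_step: bool,
                            num_lookahead_slots: int) -> int:
    if is_prefill:
        if is_multi_step and enable_chunking:
            # The prefill becomes a decode after step 1, so it needs one
            # slot for that first sampled token plus the remaining steps.
            return num_lookahead_slots + 1
        return 0
    return num_lookahead_slots

# num_scheduler_steps=8 -> num_lookahead_slots=7 for decodes,
# 8 for prefills carried through the remaining steps.
assert get_num_lookahead_slots(False, True, True, 7) == 7
assert get_num_lookahead_slots(True, True, True, 7) == 8
assert get_num_lookahead_slots(True, False, True, 7) == 0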
@ -1565,6 +1623,16 @@ class Scheduler:
|
||||
if remaining_token_budget < num_new_tokens:
|
||||
num_new_tokens = (remaining_token_budget //
|
||||
block_size) * block_size
|
||||
elif self.scheduler_config.is_multi_step:
|
||||
if num_new_tokens > self._get_prompt_limit(seq_group):
|
||||
# If the seq_group is in prompt-stage, pass the
|
||||
# num_new_tokens as-is so the caller can ignore
|
||||
# the sequence.
|
||||
pass
|
||||
else:
|
||||
num_new_tokens = 0 \
|
||||
if num_new_tokens > remaining_token_budget \
|
||||
else num_new_tokens
|
||||
else:
|
||||
num_new_tokens = min(num_new_tokens, remaining_token_budget)
|
||||
return num_new_tokens
|
||||
|
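The multi-step branch added above never chunks a prompt: it is either scheduled whole or deferred. A standalone sketch (helper name illustrative):

def multi_step_prompt_tokens(num_new_tokens: int,
                             remaining_token_budget: int,
                             prompt_limit: int) -> int:
    if num_new_tokens > prompt_limit:
        # Pass through as-is so the caller can ignore the over-long sequence.
        return num_new_tokens
    return 0 if num_new_tokens > remaining_token_budget else num_new_tokens

assert multi_step_prompt_tokens(600, 512, 4096) == 0      # deferred, not chunked
assert multi_step_prompt_tokens(300, 512, 4096) == 300    # scheduled whole
assert multi_step_prompt_tokens(9000, 512, 4096) == 9000  # rejected by caller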
@ -980,9 +980,13 @@ class EngineArgs:
|
||||
if speculative_config is not None:
|
||||
raise ValueError("Speculative decoding is not supported with "
|
||||
"multi-step (--num-scheduler-steps > 1)")
|
||||
if self.enable_chunked_prefill:
|
||||
raise ValueError("Chunked prefill is not supported with "
|
||||
"multi-step (--num-scheduler-steps > 1)")
|
||||
if self.enable_chunked_prefill and self.enable_prefix_caching:
|
||||
raise ValueError("Multi-Step is not supported with "
|
||||
"both Chunked-Prefill and Prefix-Caching "
|
||||
"enabled together.")
|
||||
if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
|
||||
raise ValueError("Multi-Step Chunked-Prefill is not supported "
|
||||
"for pipeline-parallel-size > 1")
|
||||
|
||||
# make sure num_lookahead_slots is set the higher value depending on
|
||||
# if we are using speculative decoding or multi-step
|
||||
|
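With these checks in place, the combination is enabled roughly as in the multi-step tests earlier in this diff. A usage sketch (model name illustrative; FLASH_ATTN must be the active attention backend, e.g. via VLLM_ATTENTION_BACKEND):

from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",      # illustrative model choice
    use_v2_block_manager=True,      # required for multi-step scheduling
    num_scheduler_steps=8,          # multi-step: 8 GPU-side steps per schedule
    enable_chunked_prefill=True,    # now allowed together with multi-step
    # enable_prefix_caching=True or pipeline_parallel_size > 1 would be
    # rejected by the checks above.
)
outputs = llm.generate(["The future of AI is"])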
@ -363,11 +363,18 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
self.cached_scheduler_outputs[
|
||||
virtual_engine] = SchedulerOutputState()
|
||||
|
||||
# is_first_step_output is True only when the num_steps of all
|
||||
# the sequences are 1. When the num_steps > 1,
|
||||
# multi_step_model_runner does the first-step output append.
|
||||
is_first_step_output: bool = False if not seq_group_metadata_list \
|
||||
else seq_group_metadata_list[0].state.num_steps == 1
|
||||
|
||||
ctx.append_output(outputs=outputs,
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
scheduler_outputs=scheduler_outputs,
|
||||
is_async=allow_async_output_proc,
|
||||
is_last_step=True)
|
||||
is_last_step=True,
|
||||
is_first_step_output=is_first_step_output)
|
||||
|
||||
if outputs and allow_async_output_proc:
|
||||
assert len(
|
||||
|
@ -90,6 +90,12 @@ class OutputData(NamedTuple):
|
||||
scheduler_outputs: SchedulerOutputs
|
||||
is_async: bool
|
||||
is_last_step: bool
|
||||
# Indicates if this output is from the first step of the
|
||||
# multi-step. When multi-step is disabled, this is always
|
||||
# set to True.
|
||||
# is_first_step_output is invalid when `outputs` has
|
||||
# outputs from multiple steps.
|
||||
is_first_step_output: Optional[bool]
|
||||
skip: List[int]
|
||||
|
||||
|
||||
@ -108,13 +114,15 @@ class SchedulerContext:
|
||||
def append_output(self, outputs: List[SamplerOutput],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
scheduler_outputs: SchedulerOutputs, is_async: bool,
|
||||
is_last_step: bool):
|
||||
is_last_step: bool,
|
||||
is_first_step_output: Optional[bool]):
|
||||
self.output_queue.append(
|
||||
OutputData(outputs=outputs,
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
scheduler_outputs=scheduler_outputs,
|
||||
is_async=is_async,
|
||||
is_last_step=is_last_step,
|
||||
is_first_step_output=is_first_step_output,
|
||||
skip=[]))
|
||||
|
||||
|
||||
@ -237,9 +245,10 @@ class LLMEngine:
|
||||
"quantization_param_path=%s, device_config=%s, "
|
||||
"decoding_config=%r, observability_config=%r, "
|
||||
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
|
||||
"num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
|
||||
"enable_prefix_caching=%s, use_async_output_proc=%s, "
|
||||
"use_cached_outputs=%s, mm_processor_kwargs=%s)",
|
||||
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
|
||||
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
|
||||
"use_async_output_proc=%s, use_cached_outputs=%s, "
|
||||
"mm_processor_kwargs=%s)",
|
||||
VLLM_VERSION,
|
||||
model_config.model,
|
||||
speculative_config,
|
||||
@ -270,6 +279,7 @@ class LLMEngine:
|
||||
model_config.served_model_name,
|
||||
scheduler_config.use_v2_block_manager,
|
||||
scheduler_config.num_scheduler_steps,
|
||||
scheduler_config.chunked_prefill_enabled,
|
||||
scheduler_config.multi_step_stream_outputs,
|
||||
cache_config.enable_prefix_caching,
|
||||
model_config.use_async_output_proc,
|
||||
@ -957,8 +967,66 @@ class LLMEngine:
|
||||
|
||||
ctx: The virtual engine context to work on
|
||||
request_id: If provided, then only this request is going to be processed
|
||||
|
||||
"""
|
||||
|
||||
def update_prefill_num_computed_tokens(
|
||||
seq_group: SequenceGroup,
|
||||
seq_group_meta: SequenceGroupMetadata, num_outputs: int,
|
||||
is_first_step_output: Optional[bool]) -> None:
|
||||
"""
|
||||
When multi-step and chunked-prefill are enabled together, the
|
||||
prefill sequences scheduled for multi-step execution turn into
|
||||
decodes in the first step itself. This function accounts
|
||||
for that conversion.
|
||||
|
||||
seq_group: SequenceGroup - A prefill seq_group
|
||||
seq_group_meta: SequenceGroupMetadata - Metadata of the given
|
||||
prefill seq_group
|
||||
num_outputs: int - number of output tokens being processed for the
|
||||
given seq_group
|
||||
is_first_step_output: Optional[bool] -
|
||||
If multi-step is enabled and num_outputs is 1, this value
|
||||
indicates if this output belongs to the first step of the
|
||||
multi-step.
|
||||
If multi-step is enabled and num_outputs > 1, this value
|
||||
must be None, as num_outputs > 1 indicates that outputs from
|
||||
all the steps in multi-step are submitted in a single burst.
|
||||
When multi-step is disabled, this value is always True.
|
||||
"""
|
||||
|
||||
assert seq_group_meta.is_prompt
|
||||
|
||||
token_chunk_size = seq_group_meta.token_chunk_size
|
||||
|
||||
if num_outputs == 1:
|
||||
assert is_first_step_output is not None
|
||||
|
||||
if seq_group_meta.state.num_steps == 1:
|
||||
assert is_first_step_output is True
|
||||
seq_group.update_num_computed_tokens(token_chunk_size)
|
||||
return
|
||||
|
||||
# multi-step prefill is only supported when multi-step is
|
||||
# enabled with chunked prefill
|
||||
assert self.scheduler_config.is_multi_step and \
|
||||
self.scheduler_config.chunked_prefill_enabled
|
||||
if is_first_step_output is True:
|
||||
# This sequence is a prompt during the first step only.
|
||||
seq_group.update_num_computed_tokens(token_chunk_size)
|
||||
return
|
||||
|
||||
assert is_first_step_output is None
|
||||
|
||||
# multi-step prefill is only supported when multi-step is
|
||||
# enabled with chunked prefill. Outputs from all the steps are
|
||||
# submitted in a single burst.
|
||||
assert self.scheduler_config.is_multi_step and \
|
||||
self.scheduler_config.chunked_prefill_enabled
|
||||
assert num_outputs == seq_group_meta.state.num_steps, \
|
||||
f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}" #noqa
|
||||
# This sequence is a prompt during the first step only.
|
||||
seq_group.update_num_computed_tokens(token_chunk_size)
|
||||
|
||||
now = time.time()
|
||||
|
||||
if len(ctx.output_queue) == 0:
|
||||
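The accounting done by update_prefill_num_computed_tokens above reduces to "credit token_chunk_size exactly once, on the first step", whether outputs arrive one step at a time or as a burst. A simplified standalone sketch (helper name illustrative):

from typing import Optional

def computed_tokens_to_add(token_chunk_size: int, num_steps: int,
                           num_outputs: int,
                           is_first_step_output: Optional[bool]) -> int:
    if num_outputs == 1:
        assert is_first_step_output is not None
        if num_steps == 1 or is_first_step_output:
            return token_chunk_size   # first (or only) step of the prompt
        return 0                      # later steps: it is a decode now
    # Burst of all steps submitted together.
    assert is_first_step_output is None and num_outputs == num_steps
    return token_chunk_size

assert computed_tokens_to_add(512, 1, 1, True) == 512   # single-step engine
assert computed_tokens_to_add(512, 8, 1, True) == 512   # multi-step, step 1
assert computed_tokens_to_add(512, 8, 1, False) == 0    # multi-step, later step
assert computed_tokens_to_add(512, 8, 8, None) == 512   # burst of 8 outputs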
@ -969,20 +1037,27 @@ class LLMEngine:
|
||||
# When we process only one request, no pop is required
|
||||
# (since later we will process all of the rest)
|
||||
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
|
||||
is_last_step, skip) = ctx.output_queue[0]
|
||||
is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
|
||||
else:
|
||||
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
|
||||
is_last_step, skip) = ctx.output_queue.popleft()
|
||||
is_last_step, is_first_step_output,
|
||||
skip) = ctx.output_queue.popleft()
|
||||
|
||||
# Sanity check
|
||||
assert len(seq_group_metadata_list) == len(
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
|
||||
# Organize outputs by [step][sequence group] instead of
|
||||
# [sequence group][step].
|
||||
if len(outputs) > 1:
|
||||
has_multiple_outputs: bool = len(outputs) > 1
|
||||
if has_multiple_outputs:
|
||||
assert self.scheduler_config.is_multi_step or \
|
||||
self.speculative_config
|
||||
# Organize outputs by [step][sequence group] instead of
|
||||
# [sequence group][step].
|
||||
outputs_by_sequence_group = create_output_by_sequence_group(
|
||||
outputs, num_seq_groups=len(seq_group_metadata_list))
|
||||
# We have outputs for multiple steps submitted in a single burst,
|
||||
# so invalidate is_first_step_output.
|
||||
is_first_step_output = None
|
||||
else:
|
||||
outputs_by_sequence_group = outputs
|
||||
|
||||
@ -1018,14 +1093,17 @@ class LLMEngine:
|
||||
finished_before.append(i)
|
||||
continue
|
||||
|
||||
if len(outputs) > 1:
|
||||
if has_multiple_outputs:
|
||||
output = outputs_by_sequence_group[i]
|
||||
else:
|
||||
output = [outputs_by_sequence_group[0][i]]
|
||||
|
||||
if not is_async:
|
||||
seq_group.update_num_computed_tokens(
|
||||
scheduled_seq_group.token_chunk_size)
|
||||
if not is_async and seq_group_meta.is_prompt:
|
||||
# Updates for all decodes happen when we actually append the
|
||||
# token ids to the seq in process_outputs.
|
||||
update_prefill_num_computed_tokens(seq_group, seq_group_meta,
|
||||
len(output),
|
||||
is_first_step_output)
|
||||
|
||||
if outputs:
|
||||
for o in outputs:
|
||||
@ -1159,8 +1237,18 @@ class LLMEngine:
|
||||
if seq_group.is_finished():
|
||||
continue
|
||||
|
||||
seq_group.update_num_computed_tokens(
|
||||
seq_group_metadata.token_chunk_size)
|
||||
if seq_group_metadata.is_prompt:
|
||||
if self.scheduler_config.is_multi_step and \
|
||||
self.scheduler_config.chunked_prefill_enabled:
|
||||
# Prompts are scheduled in multi-step only when
|
||||
# chunking is enabled. These prompts turn into
|
||||
# decodes after the very first step. Therefore,
|
||||
# we skip the update to the num_computed_tokens
|
||||
# here.
|
||||
pass
|
||||
else:
|
||||
seq_group.update_num_computed_tokens(
|
||||
seq_group_metadata.token_chunk_size)
|
||||
|
||||
if seq_group_metadata.do_sample:
|
||||
assert len(sequence_group_outputs.samples) == 1, (
|
||||
@ -1172,6 +1260,7 @@ class LLMEngine:
|
||||
assert len(seq_group.seqs) == 1
|
||||
seq = seq_group.seqs[0]
|
||||
seq.append_token_id(sample.output_token, sample.logprobs)
|
||||
seq_group.update_num_computed_tokens(1)
|
||||
|
||||
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
@ -1324,12 +1413,19 @@ class LLMEngine:
|
||||
if self.scheduler_config.is_multi_step:
|
||||
self.cached_scheduler_outputs[0] = SchedulerOutputState()
|
||||
|
||||
# is_first_step_output is True only when the num_steps of all
|
||||
# the sequences are 1. When the num_steps > 1,
|
||||
# multi_step_model_runner does the first-step output append.
|
||||
is_first_step_output: bool = False if not seq_group_metadata_list \
|
||||
else seq_group_metadata_list[0].state.num_steps == 1
|
||||
|
||||
# Add results to the output_queue
|
||||
ctx.append_output(outputs=outputs,
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
scheduler_outputs=scheduler_outputs,
|
||||
is_async=allow_async_output_proc,
|
||||
is_last_step=True)
|
||||
is_last_step=True,
|
||||
is_first_step_output=is_first_step_output)
|
||||
|
||||
if outputs and allow_async_output_proc:
|
||||
assert len(outputs) == 1, (
|
||||
|
@ -170,6 +170,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
token_id=output_token_id,
|
||||
logprobs=output_logprob,
|
||||
)
|
||||
seq.data.update_num_computed_tokens(1)
|
||||
|
||||
self._process_decode_and_stop(seq, sampling_params)
|
||||
|
||||
|
@ -743,10 +743,35 @@ class SequenceGroup:
|
||||
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
|
||||
if self.prompt_adapter_request else 0
|
||||
|
||||
def init_multi_step(self, num_scheduler_steps: int) -> None:
|
||||
self.state.num_steps = num_scheduler_steps
|
||||
def init_multi_step(self, num_steps: int) -> None:
|
||||
self.state.num_steps = num_steps
|
||||
self.state.current_step = 0
|
||||
|
||||
def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
|
||||
num_scheduler_steps: int,
|
||||
is_multi_step: bool,
|
||||
enable_chunking: bool) -> None:
|
||||
|
||||
if not is_multi_step:
|
||||
self.init_multi_step(num_steps=num_scheduler_steps)
|
||||
return
|
||||
|
||||
# Multi-Step case
|
||||
is_prefill = self.is_prefill()
|
||||
|
||||
# The asserts below reflect the expectations of the current system.
|
||||
if is_prefill and enable_chunking:
|
||||
assert num_lookahead_slots == num_scheduler_steps
|
||||
self.init_multi_step(num_steps=num_lookahead_slots)
|
||||
else:
|
||||
is_decode: bool = not is_prefill
|
||||
# If it is a prefill, num_lookahead_slots must be 0
|
||||
assert num_lookahead_slots == 0 or is_decode
|
||||
# If it is a decode, num_lookahead_slots + 1 must match
|
||||
# the scheduler steps.
|
||||
assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
|
||||
self.init_multi_step(num_steps=num_lookahead_slots + 1)
|
||||
|
||||
def get_last_latency(self, now: float) -> Optional[float]:
|
||||
"""Sets the last token time for Request level timings."""
|
||||
# If still in prefill phase, raise Error.
|
||||
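The step-count selection in init_multi_step_from_lookahead_slots above, condensed into a standalone helper (name illustrative):

def num_steps_for_group(is_multi_step: bool, is_prefill: bool,
                        enable_chunking: bool, num_lookahead_slots: int,
                        num_scheduler_steps: int) -> int:
    if not is_multi_step:
        return num_scheduler_steps
    if is_prefill and enable_chunking:
        # Prefill scheduled under multi-step + chunked prefill:
        # its lookahead slots already include the first sampled token.
        assert num_lookahead_slots == num_scheduler_steps
        return num_lookahead_slots
    # Decode: one step happens anyway, lookahead covers the rest.
    return num_lookahead_slots + 1

assert num_steps_for_group(False, False, False, 0, 1) == 1
assert num_steps_for_group(True, True, True, 8, 8) == 8
assert num_steps_for_group(True, False, False, 7, 8) == 8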
@ -1010,6 +1035,20 @@ class SequenceGroupMetadata(
|
||||
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
|
||||
if self.prompt_adapter_request else 0
|
||||
|
||||
# Multi-Step Chunked-Prefill property
|
||||
@property
|
||||
def is_single_step_prompt(self) -> bool:
|
||||
# do_sample is true only when the token_chunk_size matches the
|
||||
# num_uncomputed_tokens of the sequence. This indicates that
|
||||
# the prompt will finish processing in a single `execute_model`
|
||||
# step.
|
||||
return self.is_prompt and self.do_sample
|
||||
|
||||
def get_first_seq_id(self) -> int:
|
||||
# This is an efficient way of fetching the seq_id when
|
||||
# we know this SequenceGroup has only one sequence.
|
||||
return next(iter(self.seq_data))
|
||||
|
||||
def apply_delta(self,
|
||||
sequence_group_metadata_delta: SequenceGroupMetadataDelta):
|
||||
for id, delta in sequence_group_metadata_delta.seq_data_delta.items():
|
||||
@ -1022,7 +1061,8 @@ class SequenceGroupMetadata(
|
||||
|
||||
def finish_step(self) -> None:
|
||||
assert self.state is not None
|
||||
assert self.state.current_step < self.state.num_steps
|
||||
assert self.state.current_step < self.state.num_steps, \
|
||||
f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa
|
||||
self.state.current_step += 1
|
||||
|
||||
|
||||
|
@ -14,7 +14,7 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
|
||||
get_pythonized_sample_results)
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
|
||||
Logprob, SequenceGroupMetadata, SequenceOutput)
|
||||
from vllm.utils import PyObjectCache
|
||||
from vllm.utils import PyObjectCache, async_tensor_h2d
|
||||
from vllm.worker.model_runner import (GPUModelRunnerBase,
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
from vllm.worker.model_runner_base import (
|
||||
@ -30,6 +30,14 @@ if TYPE_CHECKING:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"]
|
||||
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["flash-attn"]
|
||||
|
||||
def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
|
||||
-> List[str]:
|
||||
if chunked_prefill_enabled:
|
||||
return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
|
||||
else:
|
||||
return MULTI_STEP_ATTENTION_BACKENDS
|
||||
|
||||
|
||||
def seq_output_builder():
|
||||
@ -144,11 +152,13 @@ class StatefulModelInput(BroadcastableModelInput):
|
||||
is_multi_step: bool = True
|
||||
is_last_step: bool = False
|
||||
is_first_multi_step: bool = False
|
||||
base_output_proc_callback: Optional[Callable] = None
|
||||
# ping-pong data structures for multi-step to wait on the previous step
|
||||
step_cuda_events: List[torch.cuda.Event] = field(
|
||||
default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2)
|
||||
num_seqs: int = -1
|
||||
num_queries: int = -1
|
||||
num_single_step_prefills: int = 0
|
||||
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
assert self.frozen_model_input is not None
|
||||
@ -161,6 +171,7 @@ class StatefulModelInput(BroadcastableModelInput):
|
||||
'is_first_multi_step': self.is_first_multi_step,
|
||||
'num_seqs': self.num_seqs,
|
||||
'num_queries': self.num_queries,
|
||||
'num_single_step_prefills': self.num_single_step_prefills,
|
||||
}
|
||||
tensor_dict.update(new_tensor_dict)
|
||||
return tensor_dict
|
||||
@ -209,6 +220,81 @@ class StatefulModelInput(BroadcastableModelInput):
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
pythonized=False))
|
||||
|
||||
def maybe_advance_sampling_metadata(self, device: str, pin_memory: bool):
|
||||
"""
|
||||
sampling_metadata.selected_token_indices is constructed for the
|
||||
first-step in Multi-Step. However, when chunked-prefill is enabled with
|
||||
multi-step, the scheduled prompts are fully processed in the
|
||||
first-step and are processed as decodes in the rest of the steps.
|
||||
This function updates the sampling_metadata.selected_token_indices
|
||||
to account for this conversion.
|
||||
|
||||
Example:
|
||||
Let 2 prompts and 2 decodes be scheduled together. Let the
|
||||
num-tokens to process for the 2 prompts be 5 and 8 respectively.
|
||||
|
||||
In that case, sampling_metadata.selected_token_indices will be
|
||||
[4, 12, 13, 14] as it is constructed for the first-step in
|
||||
multi-step.
|
||||
However, the prompts turn into decodes after the first step
|
||||
and the num-tokens for the previously-prompt sequences will
|
||||
be 1 and 1 as they are decodes now. The selected_token_indices
|
||||
must be updated to [0,1,2,3].
|
||||
"""
|
||||
assert self.current_step == 1 and self.num_single_step_prefills > 0
|
||||
if not get_pp_group().is_last_rank:
|
||||
return
|
||||
|
||||
assert self.frozen_model_input is not None
|
||||
assert self.frozen_model_input.sampling_metadata is not None
|
||||
self.frozen_model_input.sampling_metadata.selected_token_indices = \
|
||||
async_tensor_h2d(list(range(self.num_queries)),
|
||||
dtype=torch.long,
|
||||
target_device=device,
|
||||
pin_memory=pin_memory)
|
||||
|
||||
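A concrete version of the docstring example above (helper names illustrative):

# With prompts of 5 and 8 tokens plus 2 decodes, step 1 samples at flat
# indices [4, 12, 13, 14]; from step 2 on, every sequence contributes one
# token, so the indices collapse to [0, 1, 2, 3] (one entry per query).
def first_step_indices(prompt_lens, num_decodes):
    indices, offset = [], 0
    for plen in prompt_lens:
        offset += plen
        indices.append(offset - 1)       # last token of each prompt
    indices.extend(range(offset, offset + num_decodes))
    return indices

def later_step_indices(num_queries):
    return list(range(num_queries))      # what maybe_advance_sampling_metadata sets

assert first_step_indices([5, 8], 2) == [4, 12, 13, 14]
assert later_step_indices(4) == [0, 1, 2, 3]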
def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool):
|
||||
"""
|
||||
Advancing the data structures of StatefulModelInput::frozen_model_input
|
||||
is only required when prefills are scheduled with decodes to run in
|
||||
multi-step. This advancement/correction is required to account for
|
||||
the conversion of Prefills to Decodes after the first multi-step.
|
||||
"""
|
||||
if self.current_step != 1 or self.num_single_step_prefills == 0:
|
||||
return
|
||||
|
||||
assert self.frozen_model_input is not None
|
||||
fmi = self.frozen_model_input
|
||||
|
||||
# Truncate input_tokens
|
||||
assert fmi.input_tokens is not None
|
||||
assert fmi.input_tokens.shape[0] >= self.num_seqs
|
||||
fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs]
|
||||
|
||||
# Update frozen_model_input::input_positions.
|
||||
assert fmi.input_positions is not None
|
||||
assert fmi.input_positions.shape[0] >= self.num_seqs
|
||||
fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self.
|
||||
num_seqs]
|
||||
|
||||
# Assert unsupported
|
||||
assert fmi.lora_mapping is None
|
||||
assert fmi.lora_requests is not None
|
||||
assert len(fmi.lora_requests) == 0
|
||||
assert fmi.attn_metadata is not None
|
||||
assert fmi.prompt_adapter_mapping is None
|
||||
assert fmi.prompt_adapter_requests is not None
|
||||
assert len(fmi.prompt_adapter_requests) == 0
|
||||
assert fmi.multi_modal_kwargs is not None
|
||||
assert len(fmi.multi_modal_kwargs) == 0
|
||||
|
||||
self.frozen_model_input = dataclasses.replace(
|
||||
self.frozen_model_input,
|
||||
input_tokens=fmi_new_input_tokens,
|
||||
input_positions=fmi_new_input_positions)
|
||||
|
||||
self.maybe_advance_sampling_metadata(device, pin_memory)
|
||||
|
||||
|
||||
# MutableModelInputForGPUWithMultiStepMetadata is not subclass of
|
||||
# ModelInputForGPU but it wraps the actual input dataclass and adds multi-step
|
||||
@ -220,6 +306,19 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Check attention backend support.
|
||||
supported_attention_backends: List[str] = \
|
||||
_get_supported_attention_backends(
|
||||
self.scheduler_config.chunked_prefill_enabled)
|
||||
if self.attn_backend.get_name() not in supported_attention_backends:
|
||||
ms_config_str: str = "Multi-Step + Chunked-Prefill" \
|
||||
if self.scheduler_config.chunked_prefill_enabled \
|
||||
else "Multi-Step"
|
||||
raise ValueError(
|
||||
f"{ms_config_str} not supported for attention backend: "
|
||||
f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND "
|
||||
f"to a value from {supported_attention_backends}.")
|
||||
|
||||
# uses the base model runner to execute the model and wraps it with
|
||||
# multi-step logic
|
||||
self._base_model_runner: GPUModelRunnerBase = base_model_runner
|
||||
@ -248,14 +347,25 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> StatefulModelInput:
|
||||
frozen_model_input = self._base_model_runner.prepare_model_input(
|
||||
seq_group_metadata_list, virtual_engine, finished_requests_ids)
|
||||
frozen_model_input: ModelInputForGPUWithSamplingMetadata = \
|
||||
self._base_model_runner.prepare_model_input(
|
||||
seq_group_metadata_list,
|
||||
virtual_engine,
|
||||
finished_requests_ids)
|
||||
|
||||
assert frozen_model_input.query_lens is not None
|
||||
assert frozen_model_input.seq_lens is not None
|
||||
assert frozen_model_input.attn_metadata is not None
|
||||
num_queries = len(frozen_model_input.query_lens)
|
||||
num_seqs = len(frozen_model_input.seq_lens)
|
||||
num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills
|
||||
|
||||
model_input = StatefulModelInput(
|
||||
frozen_model_input=frozen_model_input,
|
||||
num_seqs=len(frozen_model_input.seq_lens),
|
||||
num_queries=len(frozen_model_input.query_lens),
|
||||
)
|
||||
num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
num_single_step_prefills=num_single_step_prefills)
|
||||
|
||||
return model_input
|
||||
|
||||
def _async_process_outputs(self, model_input: StatefulModelInput,
|
||||
@ -265,7 +375,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
output_proc_callback()
|
||||
|
||||
cont = True
|
||||
for model_output in model_input.cached_outputs:
|
||||
for step_num, model_output in enumerate(model_input.cached_outputs):
|
||||
if not model_output.pythonized:
|
||||
model_output.maybe_pythonize(model_input, self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
@ -276,7 +386,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
seq_group_metadata_list=ctx.seq_group_metadata_list,
|
||||
scheduler_outputs=ctx.scheduler_outputs,
|
||||
is_async=False,
|
||||
is_last_step=False)
|
||||
is_last_step=False,
|
||||
is_first_step_output=step_num == 0)
|
||||
|
||||
output_proc_callback()
|
||||
else:
|
||||
@ -292,9 +403,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
has_async_callback = output_proc_callback is not None
|
||||
|
||||
outputs = []
|
||||
for output_id in range(len(model_input.cached_outputs)):
|
||||
output = model_input.cached_outputs[output_id]
|
||||
is_last_step = output_id == len(model_input.cached_outputs) - 1
|
||||
for step_num, output in enumerate(model_input.cached_outputs):
|
||||
is_last_step = step_num == len(model_input.cached_outputs) - 1
|
||||
|
||||
# For non-async case:
|
||||
# -- We simply add the outputs
|
||||
@ -323,7 +433,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
seq_group_metadata_list,
|
||||
scheduler_outputs=ctx.scheduler_outputs,
|
||||
is_async=False,
|
||||
is_last_step=False)
|
||||
is_last_step=False,
|
||||
is_first_step_output=step_num == 0)
|
||||
else:
|
||||
outputs.append(output.sampler_output)
|
||||
else:
|
||||
@ -389,18 +500,27 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
model_input = self._advance_step(
|
||||
model_input, model_input.cached_outputs[-1].sampler_output)
|
||||
|
||||
output_proc_callback = None
|
||||
# frozen_model_input may have been updated
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
|
||||
if model_input.base_output_proc_callback is None:
|
||||
assert frozen_model_input is not None
|
||||
model_input.base_output_proc_callback = \
|
||||
frozen_model_input.async_callback
|
||||
|
||||
if frozen_model_input.async_callback is not None:
|
||||
output_proc_callback = frozen_model_input.async_callback
|
||||
assert output_proc_callback is not None
|
||||
assert model_input.base_output_proc_callback is not None
|
||||
async_callback = functools.partial(
|
||||
self._async_process_outputs,
|
||||
model_input=model_input,
|
||||
output_proc_callback=output_proc_callback)
|
||||
output_proc_callback=model_input.base_output_proc_callback)
|
||||
|
||||
frozen_model_input = dataclasses.replace( # type: ignore
|
||||
model_input.frozen_model_input = dataclasses.replace( # type: ignore
|
||||
model_input.frozen_model_input,
|
||||
async_callback=async_callback)
|
||||
# Update the local instance
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
|
||||
# Execute the model
|
||||
@ -455,8 +575,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
|
||||
# Pythonize the output and block if needed since it is the last step
|
||||
if model_input.is_last_step:
|
||||
outputs = self._final_process_outputs(model_input,
|
||||
output_proc_callback)
|
||||
outputs = self._final_process_outputs(
|
||||
model_input, model_input.base_output_proc_callback)
|
||||
self.pythonization_cache.reset()
|
||||
return outputs
|
||||
|
||||
@ -484,11 +604,14 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
|
||||
def _advance_step(self, model_input: StatefulModelInput,
|
||||
out: SamplerOutput) -> StatefulModelInput:
|
||||
if self.attn_backend.get_name() not in MULTI_STEP_ATTENTION_BACKENDS:
|
||||
raise ValueError(
|
||||
f"Multi-step not supported for attention backend: "
|
||||
f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND "
|
||||
f"to a value from {MULTI_STEP_ATTENTION_BACKENDS}.")
|
||||
|
||||
model_input.maybe_advance_frozen_model_input(self.device,
|
||||
self.pin_memory)
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
assert frozen_model_input.input_tokens is not None
|
||||
assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs
|
||||
assert frozen_model_input.attn_metadata is not None
|
||||
|
||||
sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
|
||||
num_seqs = model_input.num_seqs
|
||||
@ -498,13 +621,15 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
attn_metadata = frozen_model_input.attn_metadata
|
||||
assert attn_metadata is not None
|
||||
|
||||
turn_prefills_into_decodes: bool = model_input.current_step == 1 and \
|
||||
model_input.num_single_step_prefills != 0
|
||||
attn_metadata.advance_step(
|
||||
frozen_model_input,
|
||||
sampled_token_ids,
|
||||
self.block_size,
|
||||
num_seqs,
|
||||
num_queries,
|
||||
)
|
||||
turn_prefills_into_decodes=turn_prefills_into_decodes)
|
||||
|
||||
return model_input
|
||||
|
||||
|
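The trigger for the conversion computed in _advance_step above, written as a standalone predicate (name illustrative):

def should_turn_prefills_into_decodes(current_step: int,
                                      num_single_step_prefills: int) -> bool:
    # Only on the transition out of the first step, and only if the batch
    # actually contained single-step prefills.
    return current_step == 1 and num_single_step_prefills != 0

assert should_turn_prefills_into_decodes(1, 2) is True
assert should_turn_prefills_into_decodes(1, 0) is False   # decode-only batch
assert should_turn_prefills_into_decodes(2, 2) is False   # already converted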
@ -76,8 +76,9 @@ class MultiStepWorker(Worker):
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
assert frozen_model_input.attn_metadata is not None
# clear the cached decode metadata so that it can be recomputed on
# the workers
# clear the cached metadata so that it can be recomputed on
# the workers.
frozen_model_input.attn_metadata._cached_prefill_metadata = None
frozen_model_input.attn_metadata._cached_decode_metadata = None

model_input.is_first_multi_step = is_first_multi_step