[Bugfix] Fix async postprocessor in case of preemption (#8267)

commit 4ef41b8476 (parent cfe712bf1a)
Author:    Alexander Matveev
Committer: GitHub
Date:      2024-09-08 00:01:51 -04:00
4 changed files with 171 additions and 113 deletions

vllm/core/scheduler.py

@@ -537,13 +537,6 @@ class Scheduler:
preempted: List[SequenceGroup] = ret.preempted
swapped_out: List[SequenceGroup] = ret.swapped_out
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# Store original running requests for the case of async + preemption
if self.use_async_output_proc:
orig_running = self.running.copy()
running_queue = self.running
assert len(self._async_stopped) == 0
while running_queue:
@@ -552,6 +545,7 @@ class Scheduler:
seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
if num_running_tokens == 0:
# No budget => Stop
break
running_queue.popleft()
@@ -565,18 +559,8 @@ class Scheduler:
self._async_stopped.append(seq_group)
continue
# With async postprocessor, when preemption kicks in, we need
# first to drain the async postprocessor, so that all async
# block_table freeing is applied before the preemption freeing
# is applied.
if self.use_async_output_proc and not self._can_append_slots(
seq_group):
tmp = self.running
self.running = orig_running
assert self.output_proc_callback is not None
self.output_proc_callback()
self.running = tmp
# NOTE(woosuk): Preemption happens only when there is no available
# slot to keep all the sequence groups in the RUNNING state.
while not self._can_append_slots(seq_group):
budget.subtract_num_batched_tokens(seq_group.request_id,
num_running_tokens)
@@ -588,24 +572,43 @@ class Scheduler:
and seq_group.lora_int_id in curr_loras):
curr_loras.remove(seq_group.lora_int_id)
# Determine victim sequence
cont_loop = True
if running_queue:
# Preempt the lowest-priority sequence groups.
# Preempt the lowest-priority sequence group.
victim_seq_group = running_queue.pop()
else:
# No other sequence group can be preempted.
# Preempt the current sequence group.
# Note: This is also where we stop this loop
# (since there is nothing else to preempt)
victim_seq_group = seq_group
cont_loop = False
# With async postprocessor, before preempting a sequence
# we need to ensure it has no pending async postprocessor
do_preempt = True
if self.use_async_output_proc:
assert self.output_proc_callback is not None
self.output_proc_callback(
request_id=victim_seq_group.request_id)
# It may be that the async pending "victim_seq_group"
# becomes finished, in which case we simply free it.
if victim_seq_group.is_finished():
self._free_finished_seq_group(victim_seq_group)
do_preempt = False
# Do preemption
if do_preempt:
preempted_mode = self._preempt(victim_seq_group,
blocks_to_swap_out)
if preempted_mode == PreemptionMode.RECOMPUTE:
preempted.append(victim_seq_group)
else:
swapped_out.append(victim_seq_group)
else:
# No other sequence groups can be preempted.
# Preempt the current sequence group.
preempted_mode = self._preempt(seq_group,
blocks_to_swap_out)
if preempted_mode == PreemptionMode.RECOMPUTE:
preempted.append(seq_group)
else:
swapped_out.append(seq_group)
if not cont_loop:
break
else:
self._append_slots(seq_group, blocks_to_copy)
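
For readability, a condensed sketch of the new victim-handling path introduced above (the method name here is illustrative; the flags and helpers are the ones shown in the diff, and the real logic lives inline in the scheduling loop):

    def _preempt_victim(self, victim_seq_group, blocks_to_swap_out,
                        preempted, swapped_out) -> None:
        if self.use_async_output_proc:
            # Drain only this request's pending async output so its
            # block-table frees land before the preemption frees.
            assert self.output_proc_callback is not None
            self.output_proc_callback(request_id=victim_seq_group.request_id)
            # The drained victim may already be finished; just free it.
            if victim_seq_group.is_finished():
                self._free_finished_seq_group(victim_seq_group)
                return
        preempted_mode = self._preempt(victim_seq_group, blocks_to_swap_out)
        if preempted_mode == PreemptionMode.RECOMPUTE:
            preempted.append(victim_seq_group)
        else:
            swapped_out.append(victim_seq_group)

Compared with the removed code, the scheduler no longer swaps self.running back to a saved copy and drains the whole async queue; only the victim's own request is flushed before its blocks are freed.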
@@ -1264,22 +1267,26 @@ class Scheduler:
if seq.is_finished():
self.free_seq(seq)
def free_finished_seq_groups(self) -> None:
remaining: Deque[SequenceGroup] = deque()
for seq_group in self.running:
def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None:
if seq_group.is_finished():
# Free cross-attention block table, if it exists
self._free_seq_group_cross_attn_blocks(seq_group)
# Add the finished requests to the finished requests list.
# This list will be used to update the Mamba cache in the
# next step.
self._finished_requests_ids.append(seq_group.request_id)
else:
remaining.append(seq_group)
# Free finished seqs
self._free_finished_seqs(seq_group)
def free_finished_seq_groups(self) -> None:
remaining: Deque[SequenceGroup] = deque()
for seq_group in self.running:
self._free_finished_seq_group(seq_group)
if not seq_group.is_finished():
remaining.append(seq_group)
self.running = remaining
# Handle async stopped sequence groups
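
The cleanup that used to live inline in free_finished_seq_groups is factored into _free_finished_seq_group so the preemption path above can reuse it. Condensed (a sketch, simplified from the hunk, with the interleaved old/new lines untangled):

    def free_finished_seq_groups(self) -> None:
        remaining: Deque[SequenceGroup] = deque()
        for seq_group in self.running:
            # Frees the cross-attention block table, records the request id
            # for the Mamba cache update, and frees any finished sequences.
            self._free_finished_seq_group(seq_group)
            if not seq_group.is_finished():
                remaining.append(seq_group)
        self.running = remaining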

vllm/engine/async_llm_engine.py

@@ -342,17 +342,17 @@ class _AsyncLLMEngine(LLMEngine):
virtual_engine]
# Execute the model.
output = await self.model_executor.execute_model_async(
outputs = await self.model_executor.execute_model_async(
execute_model_req)
# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
self._update_cached_scheduler_output(virtual_engine, outputs)
else:
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
output = []
outputs = []
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
@@ -365,25 +365,25 @@ class _AsyncLLMEngine(LLMEngine):
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()
is_async = allow_async_output_proc
is_last_step = True
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step))
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=allow_async_output_proc,
is_last_step=True)
if output and allow_async_output_proc:
if outputs and allow_async_output_proc:
assert len(
output
outputs
) == 1, "Async postprocessor expects only a single output set"
self._advance_to_next_step(
output[0], seq_group_metadata_list,
outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if not allow_async_output_proc:
self._process_model_outputs(ctx=ctx)
# Log stats.
self.do_log_stats(scheduler_outputs, output)
self.do_log_stats(scheduler_outputs, outputs)
# Tracing
self.do_tracing(scheduler_outputs)

vllm/engine/llm_engine.py

@@ -2,9 +2,9 @@ import functools
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
Mapping, Optional)
Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, Union
@@ -90,17 +90,36 @@ class SchedulerOutputState:
last_output: Optional[SamplerOutput] = None
@dataclass
class OutputData(NamedTuple):
outputs: List[SamplerOutput]
seq_group_metadata_list: List[SequenceGroupMetadata]
scheduler_outputs: SchedulerOutputs
is_async: bool
is_last_step: bool
skip: List[int]
class SchedulerContext:
output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
List[SequenceGroupMetadata], SchedulerOutputs,
bool,
bool]] = field(default_factory=lambda: deque())
request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = field(
default_factory=lambda: [])
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
def __init__(self):
self.output_queue: Deque[OutputData] = deque()
self.request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
self.seq_group_metadata_list: Optional[
List[SequenceGroupMetadata]] = None
self.scheduler_outputs: Optional[SchedulerOutputs] = None
def append_output(self, outputs: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduler_outputs: SchedulerOutputs, is_async: bool,
is_last_step: bool):
self.output_queue.append(
OutputData(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=is_async,
is_last_step=is_last_step,
skip=[]))
class LLMEngine:
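
The ad-hoc tuple previously pushed onto output_queue becomes a typed OutputData record with an extra skip list, and producers go through append_output. A minimal usage sketch (the step_* variable names are placeholders, not vLLM identifiers):

    ctx = SchedulerContext()
    ctx.append_output(outputs=step_outputs,              # List[SamplerOutput]
                      seq_group_metadata_list=step_metadata,
                      scheduler_outputs=step_sched_outputs,
                      is_async=allow_async_output_proc,
                      is_last_step=True)

    # Consumers read fields by name instead of tuple position; `skip`
    # tracks indices already handled by a request_id-targeted call.
    data = ctx.output_queue[0]
    assert data.skip == [] and data.is_last_step

Using a NamedTuple keeps each queue entry fixed apart from the mutable skip list, which _process_model_outputs updates in place when it handles a single request ahead of the full pass.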
@@ -1246,23 +1265,15 @@ class LLMEngine:
return
def _process_model_outputs(self, ctx: SchedulerContext) -> None:
"""Apply the model output to the sequences in the scheduled seq groups.
def _process_model_outputs(self,
ctx: SchedulerContext,
request_id: Optional[str] = None) -> None:
"""Apply the model output to the sequences in the scheduled seq groups
and return responses.
virtual_engine: The engine id to operate on
ctx: The virtual engine context to work on
request_id: If provided, then only this request is going to be processed
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
sampler_output: Used with multi-step execution to provide
sampler_output of each step
is_last_output: Used with multi-step execution to indicate
the last step (of each multi-step group)
Returns RequestOutputs that can be returned to the client.
"""
now = time.time()
@@ -1270,9 +1281,14 @@ class LLMEngine:
return None
# Get pending async postprocessor
if request_id:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step) = ctx.output_queue.popleft()
assert outputs is not None
is_last_step, skip) = ctx.output_queue[0]
else:
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, skip) = ctx.output_queue.popleft()
# Sanity check
assert len(seq_group_metadata_list) == len(
@@ -1286,9 +1302,30 @@ class LLMEngine:
else:
outputs_by_sequence_group = outputs
# Determine the requests we need to operate on
if request_id:
indices = []
for i, seq_group_meta in enumerate(seq_group_metadata_list):
if seq_group_meta.request_id == request_id:
assert i not in skip # Cannot be called twice
indices.append(i)
break
# If the request_id was not found, then it means that
# this is a new request that has no pending async
# postprocessor
if not indices:
return
else:
indices = range(len(seq_group_metadata_list)) # type: ignore
finished_before: List[int] = []
finished_now: List[int] = []
for i, seq_group_meta in enumerate(seq_group_metadata_list):
for i in indices:
if i in skip:
continue
seq_group_meta = seq_group_metadata_list[i]
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
seq_group = scheduled_seq_group.seq_group
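
The selection logic above decides which scheduled sequence groups a given call operates on; a condensed sketch of just that part (the helper name is illustrative, the behavior follows the hunk):

    def _select_indices(seq_group_metadata_list, request_id, skip):
        """Pick which scheduled entries this call processes."""
        if request_id is not None:
            # Targeted call: at most one entry, and it must not have been
            # processed by an earlier targeted call.
            indices = [i for i, meta in enumerate(seq_group_metadata_list)
                       if meta.request_id == request_id]
            assert all(i not in skip for i in indices)
            return indices   # empty => new request, nothing pending
        # Full call: everything except entries already handled and skipped.
        return [i for i in range(len(seq_group_metadata_list)) if i not in skip]

After a targeted call finishes, its index is appended to skip so the later full pass, which pops the same queue entry, does not process that sequence group twice.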
@@ -1343,6 +1380,18 @@ class LLMEngine:
request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output)
# When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output)
if request_id:
assert len(indices) == 1
skip.append(indices[0])
if (finished_now
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx.request_outputs)
ctx.request_outputs.clear()
return
# Free currently finished requests
if finished_now:
for scheduler in self.scheduler:
@@ -1354,17 +1403,16 @@ class LLMEngine:
if (finished_now
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx.request_outputs)
ctx.request_outputs.clear()
return
# Create the outputs
# Note: scheduled_seq_groups and seq_group_metadata_list
# must match with the indices
for i, scheduled_seq_group in enumerate(
scheduler_outputs.scheduled_seq_groups):
if i in finished_before or i in finished_now:
for i in indices:
if i in skip or i in finished_before or i in finished_now:
continue # Avoids double processing
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
if (seq_group.is_finished()
@@ -1380,6 +1428,7 @@ class LLMEngine:
if (ctx.request_outputs
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx.request_outputs)
ctx.request_outputs.clear()
# For async case, we need to record the stats here.
# For non-async case, the stats are done in the
@@ -1548,20 +1597,20 @@ class LLMEngine:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
output = self.model_executor.execute_model(
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
self._update_cached_scheduler_output(virtual_engine, outputs)
else:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
# No outputs in this case
output = []
outputs = []
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
@@ -1574,18 +1623,18 @@ class LLMEngine:
self.cached_scheduler_outputs[0] = SchedulerOutputState()
# Add results to the output_queue
is_async = allow_async_output_proc
is_last_step = True
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step))
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=allow_async_output_proc,
is_last_step=True)
if output and allow_async_output_proc:
assert len(output) == 1, (
if outputs and allow_async_output_proc:
assert len(outputs) == 1, (
"Async postprocessor expects only a single output set")
self._advance_to_next_step(
output[0], seq_group_metadata_list,
outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# Check if need to run the usual non-async path
@@ -1593,7 +1642,7 @@ class LLMEngine:
self._process_model_outputs(ctx=ctx)
# Log stats.
self.do_log_stats(scheduler_outputs, output)
self.do_log_stats(scheduler_outputs, outputs)
# Tracing
self.do_tracing(scheduler_outputs)

vllm/worker/multi_step_model_runner.py

@@ -274,12 +274,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
self.pinned_sampled_token_ids)
if model_output.pythonized:
ctx = output_proc_callback.keywords["ctx"]
is_async = False
is_last_step = False
ctx.output_queue.append(
([model_output.sampler_output
], ctx.seq_group_metadata_list,
ctx.scheduler_outputs, is_async, is_last_step))
ctx.append_output(
outputs=[model_output.sampler_output],
seq_group_metadata_list=ctx.seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False)
output_proc_callback()
else:
cont = False
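
Because the queue entries are now OutputData records, the multi-step runner builds them through append_output as well. A condensed sketch of the call site above (the same pattern appears in the second call site below):

    # Hand a pythonized intermediate step to the postprocessor right away.
    # is_async=False: the postprocessor still has to append these tokens.
    # is_last_step=False: more micro-steps of this multi-step group follow.
    ctx = output_proc_callback.keywords["ctx"]
    ctx.append_output(outputs=[model_output.sampler_output],
                      seq_group_metadata_list=ctx.seq_group_metadata_list,
                      scheduler_outputs=ctx.scheduler_outputs,
                      is_async=False,
                      is_last_step=False)
    output_proc_callback()   # process this step's output immediately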
@@ -319,12 +320,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
if not is_last_step:
ctx = output_proc_callback.keywords[ # type: ignore
"ctx"] # type: ignore
is_async = False
is_last_step = False
ctx.output_queue.append(
([output.sampler_output
], ctx.seq_group_metadata_list,
ctx.scheduler_outputs, is_async, is_last_step))
ctx.append_output(
outputs=[output.sampler_output],
seq_group_metadata_list=ctx.
seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False)
else:
outputs.append(output.sampler_output)
else: