[Bugfix] Fix async postprocessor in case of preemption (#8267)

Alexander Matveev 2024-09-08 00:01:51 -04:00 committed by GitHub
parent cfe712bf1a
commit 4ef41b8476
4 changed files with 171 additions and 113 deletions

Changed file 1 of 4

@@ -537,13 +537,6 @@ class Scheduler:
         preempted: List[SequenceGroup] = ret.preempted
         swapped_out: List[SequenceGroup] = ret.swapped_out
 
-        # NOTE(woosuk): Preemption happens only when there is no available slot
-        # to keep all the sequence groups in the RUNNING state.
-
-        # Store original running requests for the case of async + preemption
-        if self.use_async_output_proc:
-            orig_running = self.running.copy()
-
         running_queue = self.running
         assert len(self._async_stopped) == 0
         while running_queue:
@@ -552,6 +545,7 @@ class Scheduler:
                 seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
 
             if num_running_tokens == 0:
+                # No budget => Stop
                 break
 
             running_queue.popleft()
@@ -565,18 +559,8 @@ class Scheduler:
                 self._async_stopped.append(seq_group)
                 continue
 
-            # With async postprocessor, when preemption kicks in, we need
-            # first to drain the async postprocessor, so that all async
-            # block_table freeing is applied before the preemption freeing
-            # is applied.
-            if self.use_async_output_proc and not self._can_append_slots(
-                    seq_group):
-                tmp = self.running
-                self.running = orig_running
-                assert self.output_proc_callback is not None
-                self.output_proc_callback()
-                self.running = tmp
-
+            # NOTE(woosuk): Preemption happens only when there is no available
+            # slot to keep all the sequence groups in the RUNNING state.
             while not self._can_append_slots(seq_group):
                 budget.subtract_num_batched_tokens(seq_group.request_id,
                                                    num_running_tokens)
@@ -588,24 +572,43 @@ class Scheduler:
                         and seq_group.lora_int_id in curr_loras):
                     curr_loras.remove(seq_group.lora_int_id)
 
+                # Determine victim sequence
+                cont_loop = True
                 if running_queue:
-                    # Preempt the lowest-priority sequence groups.
+                    # Preempt the lowest-priority sequence group.
                     victim_seq_group = running_queue.pop()
+                else:
+                    # No other sequence group can be preempted.
+                    # Preempt the current sequence group.
+                    # Note: This is also where we stop this loop
+                    # (since there is nothing else to preempt)
+                    victim_seq_group = seq_group
+                    cont_loop = False
+
+                # With async postprocessor, before preempting a sequence
+                # we need to ensure it has no pending async postprocessor
+                do_preempt = True
+                if self.use_async_output_proc:
+                    assert self.output_proc_callback is not None
+                    self.output_proc_callback(
+                        request_id=victim_seq_group.request_id)
+
+                    # It may be that the async pending "victim_seq_group"
+                    # becomes finished, in which case we simply free it.
+                    if victim_seq_group.is_finished():
+                        self._free_finished_seq_group(victim_seq_group)
+                        do_preempt = False
+
+                # Do preemption
+                if do_preempt:
                     preempted_mode = self._preempt(victim_seq_group,
                                                    blocks_to_swap_out)
                     if preempted_mode == PreemptionMode.RECOMPUTE:
                         preempted.append(victim_seq_group)
                     else:
                         swapped_out.append(victim_seq_group)
-                else:
-                    # No other sequence groups can be preempted.
-                    # Preempt the current sequence group.
-                    preempted_mode = self._preempt(seq_group,
-                                                   blocks_to_swap_out)
-                    if preempted_mode == PreemptionMode.RECOMPUTE:
-                        preempted.append(seq_group)
-                    else:
-                        swapped_out.append(seq_group)
+
+                if not cont_loop:
                     break
             else:
                 self._append_slots(seq_group, blocks_to_copy)
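
Note on the hunk above: the removed approach (see the earlier hunk) swapped `self.running` back to a saved copy and drained the entire async postprocessor before preempting. The new flow drains only the chosen victim via `output_proc_callback(request_id=...)`, and frees it instead of preempting it if that drain already finished it. Below is a minimal standalone sketch of that control flow; names such as `Victim` and `preempt_with_async_drain` are illustrative and not vLLM code.

```python
from dataclasses import dataclass
from typing import Callable, Deque, Optional


@dataclass
class Victim:
    """Illustrative stand-in for a sequence group."""
    request_id: str
    finished: bool = False

    def is_finished(self) -> bool:
        return self.finished


def preempt_with_async_drain(
    seq_group: Victim,
    running_queue: Deque[Victim],
    use_async_output_proc: bool,
    output_proc_callback: Optional[Callable[[str], None]],
    preempt: Callable[[Victim], None],
    free_finished: Callable[[Victim], None],
) -> bool:
    """Returns False when the current group itself was the victim (stop looping)."""
    # Pick the victim: the lowest-priority running group, else the current group.
    cont_loop = True
    if running_queue:
        victim = running_queue.pop()
    else:
        victim = seq_group
        cont_loop = False

    # Drain pending async output processing for this one request before
    # freeing or preempting, so its block tables are in a consistent state.
    do_preempt = True
    if use_async_output_proc:
        assert output_proc_callback is not None
        output_proc_callback(victim.request_id)
        # The drained output may have finished the victim already;
        # then it only needs to be freed, not preempted.
        if victim.is_finished():
            free_finished(victim)
            do_preempt = False

    if do_preempt:
        preempt(victim)

    return cont_loop
```

The key difference from the removed code is that only the victim's pending outputs are flushed, so `self.running` no longer has to be swapped around a global drain.
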
@@ -1264,22 +1267,26 @@ class Scheduler:
             if seq.is_finished():
                 self.free_seq(seq)
 
+    def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None:
+        if seq_group.is_finished():
+            # Free cross-attention block table, if it exists
+            self._free_seq_group_cross_attn_blocks(seq_group)
+
+            # Add the finished requests to the finished requests list.
+            # This list will be used to update the Mamba cache in the
+            # next step.
+            self._finished_requests_ids.append(seq_group.request_id)
+
+        # Free finished seqs
+        self._free_finished_seqs(seq_group)
+
     def free_finished_seq_groups(self) -> None:
         remaining: Deque[SequenceGroup] = deque()
         for seq_group in self.running:
-            if seq_group.is_finished():
-                # Free cross-attention block table, if it exists
-                self._free_seq_group_cross_attn_blocks(seq_group)
-                # Add the finished requests to the finished requests list.
-                # This list will be used to update the Mamba cache in the
-                # next step.
-                self._finished_requests_ids.append(seq_group.request_id)
-            else:
+            self._free_finished_seq_group(seq_group)
+            if not seq_group.is_finished():
                 remaining.append(seq_group)
 
-            # Free finished seqs
-            self._free_finished_seqs(seq_group)
-
         self.running = remaining
 
         # Handle async stopped sequence groups
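
The `_free_finished_seq_group` helper extracted above is what lets the preemption path free an already-finished victim directly (`self._free_finished_seq_group(victim_seq_group)`) without rescanning the whole running queue. A rough standalone sketch of the extracted-helper pattern, using toy `Group`/`MiniScheduler` names rather than the real classes:

```python
from collections import deque
from typing import Deque, List


class Group:
    def __init__(self, request_id: str, finished: bool = False) -> None:
        self.request_id = request_id
        self.finished = finished

    def is_finished(self) -> bool:
        return self.finished


class MiniScheduler:
    def __init__(self) -> None:
        self.running: Deque[Group] = deque()
        self.finished_ids: List[str] = []

    def _free_finished_seq_group(self, group: Group) -> None:
        # Per-group cleanup, callable from the batch loop below or from
        # the preemption path for a single victim.
        if group.is_finished():
            self.finished_ids.append(group.request_id)

    def free_finished_seq_groups(self) -> None:
        remaining: Deque[Group] = deque()
        for group in self.running:
            self._free_finished_seq_group(group)
            if not group.is_finished():
                remaining.append(group)
        self.running = remaining
```
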

Changed file 2 of 4

@@ -342,17 +342,17 @@ class _AsyncLLMEngine(LLMEngine):
                 virtual_engine]
 
             # Execute the model.
-            output = await self.model_executor.execute_model_async(
+            outputs = await self.model_executor.execute_model_async(
                 execute_model_req)
 
             # we need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
-                self._update_cached_scheduler_output(virtual_engine, output)
+                self._update_cached_scheduler_output(virtual_engine, outputs)
         else:
             if len(ctx.output_queue) > 0:
                 self._process_model_outputs(ctx=ctx)
-            output = []
+            outputs = []
 
         # Finish the current step for all the sequence groups.
         if self.scheduler_config.is_multi_step:
@@ -365,25 +365,25 @@ class _AsyncLLMEngine(LLMEngine):
                 self.cached_scheduler_outputs[
                     virtual_engine] = SchedulerOutputState()
 
-            is_async = allow_async_output_proc
-            is_last_step = True
-            ctx.output_queue.append(
-                (output, seq_group_metadata_list, scheduler_outputs, is_async,
-                 is_last_step))
+            ctx.append_output(outputs=outputs,
+                              seq_group_metadata_list=seq_group_metadata_list,
+                              scheduler_outputs=scheduler_outputs,
+                              is_async=allow_async_output_proc,
+                              is_last_step=True)
 
-            if output and allow_async_output_proc:
+            if outputs and allow_async_output_proc:
                 assert len(
-                    output
+                    outputs
                 ) == 1, "Async postprocessor expects only a single output set"
                 self._advance_to_next_step(
-                    output[0], seq_group_metadata_list,
+                    outputs[0], seq_group_metadata_list,
                     scheduler_outputs.scheduled_seq_groups)
 
             if not allow_async_output_proc:
                 self._process_model_outputs(ctx=ctx)
 
             # Log stats.
-            self.do_log_stats(scheduler_outputs, output)
+            self.do_log_stats(scheduler_outputs, outputs)
 
             # Tracing
             self.do_tracing(scheduler_outputs)

Changed file 3 of 4

@@ -2,9 +2,9 @@ import functools
 import time
 from collections import deque
 from contextlib import contextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
-                    Mapping, Optional)
+                    Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Tuple, Type, Union
@@ -90,17 +90,36 @@ class SchedulerOutputState:
     last_output: Optional[SamplerOutput] = None
 
 
-@dataclass
+class OutputData(NamedTuple):
+    outputs: List[SamplerOutput]
+    seq_group_metadata_list: List[SequenceGroupMetadata]
+    scheduler_outputs: SchedulerOutputs
+    is_async: bool
+    is_last_step: bool
+    skip: List[int]
+
+
 class SchedulerContext:
-    output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
-                              List[SequenceGroupMetadata], SchedulerOutputs,
-                              bool,
-                              bool]] = field(default_factory=lambda: deque())
-    request_outputs: List[Union[RequestOutput,
-                                EmbeddingRequestOutput]] = field(
-                                    default_factory=lambda: [])
-    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
-    scheduler_outputs: Optional[SchedulerOutputs] = None
+
+    def __init__(self):
+        self.output_queue: Deque[OutputData] = deque()
+        self.request_outputs: List[Union[RequestOutput,
+                                         EmbeddingRequestOutput]] = []
+        self.seq_group_metadata_list: Optional[
+            List[SequenceGroupMetadata]] = None
+        self.scheduler_outputs: Optional[SchedulerOutputs] = None
+
+    def append_output(self, outputs: List[SamplerOutput],
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      scheduler_outputs: SchedulerOutputs, is_async: bool,
+                      is_last_step: bool):
+        self.output_queue.append(
+            OutputData(outputs=outputs,
+                       seq_group_metadata_list=seq_group_metadata_list,
+                       scheduler_outputs=scheduler_outputs,
+                       is_async=is_async,
+                       is_last_step=is_last_step,
+                       skip=[]))
 
 
 class LLMEngine:
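
For context on the hunk above: the queued entry becomes a `NamedTuple` (`OutputData`) instead of an anonymous 5-tuple, gains a per-entry `skip` list, and every producer now goes through `SchedulerContext.append_output`. A simplified, self-contained sketch of the same shape; field types are reduced to `Any` here for brevity and are not the real vLLM types:

```python
from collections import deque
from typing import Any, Deque, List, NamedTuple


class OutputData(NamedTuple):
    # The real entries hold SamplerOutput, SequenceGroupMetadata and
    # SchedulerOutputs objects; Any keeps this sketch standalone.
    outputs: List[Any]
    seq_group_metadata_list: List[Any]
    scheduler_outputs: Any
    is_async: bool
    is_last_step: bool
    # Indices already handled by a targeted per-request callback run,
    # so a later full drain of this queue entry can skip them.
    skip: List[int]


class SchedulerContext:
    def __init__(self) -> None:
        self.output_queue: Deque[OutputData] = deque()

    def append_output(self, outputs: List[Any],
                      seq_group_metadata_list: List[Any],
                      scheduler_outputs: Any, is_async: bool,
                      is_last_step: bool) -> None:
        # All producers use this helper, so queue entries always have the
        # same shape, including a fresh, mutable skip list.
        self.output_queue.append(
            OutputData(outputs=outputs,
                       seq_group_metadata_list=seq_group_metadata_list,
                       scheduler_outputs=scheduler_outputs,
                       is_async=is_async,
                       is_last_step=is_last_step,
                       skip=[]))


# Usage: callers no longer build raw tuples by hand.
ctx = SchedulerContext()
ctx.append_output(outputs=[], seq_group_metadata_list=[],
                  scheduler_outputs=None, is_async=True, is_last_step=True)
entry = ctx.output_queue[0]
print(entry.is_last_step, entry.skip)
```
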
@@ -1246,23 +1265,15 @@ class LLMEngine:
             return
 
-    def _process_model_outputs(self, ctx: SchedulerContext) -> None:
-        """Apply the model output to the sequences in the scheduled seq groups.
-
-        virtual_engine: The engine id to operate on
-
-        is_async: Indicates whether this postprocessor runs in
-            parallel with the GPU forward pass and is processing
-            tokens from the previous step. If this is true, then
-            no tokens need to be appended since it is already done
-            externally (before the next schedule() call)
-
-        sampler_output: Used with multi-step execution to provide
-            sampler_output of each step
-        is_last_output: Used with multi-step execution to indicate
-            the last step (of each multi-step group)
-
-        Returns RequestOutputs that can be returned to the client.
+    def _process_model_outputs(self,
+                               ctx: SchedulerContext,
+                               request_id: Optional[str] = None) -> None:
+        """Apply the model output to the sequences in the scheduled seq groups
+        and return responses.
+
+        ctx: The virtual engine context to work on
+        request_id: If provided, then only this request is going to be processed
         """
 
         now = time.time()
@@ -1270,9 +1281,14 @@ class LLMEngine:
             return None
 
         # Get pending async postprocessor
-        (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
-         is_last_step) = ctx.output_queue.popleft()
-        assert outputs is not None
+        if request_id:
+            # When we process only one request, no pop is required
+            # (since later we will process all of the rest)
+            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
+             is_last_step, skip) = ctx.output_queue[0]
+        else:
+            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
+             is_last_step, skip) = ctx.output_queue.popleft()
 
         # Sanity check
         assert len(seq_group_metadata_list) == len(
@@ -1286,9 +1302,30 @@ class LLMEngine:
         else:
             outputs_by_sequence_group = outputs
 
+        # Determine the requests we need to operate on
+        if request_id:
+            indices = []
+            for i, seq_group_meta in enumerate(seq_group_metadata_list):
+                if seq_group_meta.request_id == request_id:
+                    assert i not in skip  # Cannot be called twice
+                    indices.append(i)
+                    break
+
+            # If the request_id was not found, then it means that
+            # this is a new request that has no pending async
+            # postprocessor
+            if not indices:
+                return
+        else:
+            indices = range(len(seq_group_metadata_list))  # type: ignore
+
         finished_before: List[int] = []
         finished_now: List[int] = []
-        for i, seq_group_meta in enumerate(seq_group_metadata_list):
+        for i in indices:
+            if i in skip:
+                continue
+
+            seq_group_meta = seq_group_metadata_list[i]
             scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
 
             seq_group = scheduled_seq_group.seq_group
@@ -1343,6 +1380,18 @@ class LLMEngine:
                 request_output = RequestOutputFactory.create(seq_group)
                 ctx.request_outputs.append(request_output)
 
+        # When we process a single request, we skip it for the next time,
+        # and invoke the request output callback (if there was final output)
+        if request_id:
+            assert len(indices) == 1
+            skip.append(indices[0])
+
+            if (finished_now
+                    and self.process_request_outputs_callback is not None):
+                self.process_request_outputs_callback(ctx.request_outputs)
+                ctx.request_outputs.clear()
+            return
+
         # Free currently finished requests
         if finished_now:
             for scheduler in self.scheduler:
@@ -1354,17 +1403,16 @@ class LLMEngine:
         if (finished_now
                 and self.process_request_outputs_callback is not None):
             self.process_request_outputs_callback(ctx.request_outputs)
+            ctx.request_outputs.clear()
             return
 
         # Create the outputs
-        # Note: scheduled_seq_groups and seq_group_metadata_list
-        # must match with the indices
-        for i, scheduled_seq_group in enumerate(
-                scheduler_outputs.scheduled_seq_groups):
-            if i in finished_before or i in finished_now:
+        for i in indices:
+            if i in skip or i in finished_before or i in finished_now:
                 continue  # Avoids double processing
 
+            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             if (seq_group.is_finished()
@@ -1380,6 +1428,7 @@ class LLMEngine:
         if (ctx.request_outputs
                 and self.process_request_outputs_callback is not None):
             self.process_request_outputs_callback(ctx.request_outputs)
+            ctx.request_outputs.clear()
 
         # For async case, we need to record the stats here.
         # For non-async case, the stats are done in the
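
Taken together, the `_process_model_outputs` hunks above add a second calling mode: with `request_id` set, the method peeks at the head of the output queue (no `popleft`), processes only the matching index, records it in the entry's `skip` list, and returns; the regular full drain later pops the entry and ignores that index, so nothing is processed twice. A simplified standalone model of the two modes, with toy `Entry`/`process_outputs` names:

```python
from collections import deque
from typing import Deque, List, NamedTuple, Optional


class Entry(NamedTuple):
    # Toy stand-in for the queued OutputData: request ids only.
    request_ids: List[str]
    skip: List[int]


def process_outputs(queue: Deque[Entry],
                    handled: List[str],
                    request_id: Optional[str] = None) -> None:
    """Simplified model of the two call modes of _process_model_outputs."""
    if not queue:
        return

    if request_id is not None:
        # Targeted call: peek (no popleft), handle only that request, and
        # remember its index in skip so the later full drain ignores it.
        entry = queue[0]
        for i, rid in enumerate(entry.request_ids):
            if rid == request_id:
                assert i not in entry.skip  # must not be processed twice
                handled.append(rid)
                entry.skip.append(i)
                break
        return

    # Full drain: pop the entry and handle everything not already skipped.
    entry = queue.popleft()
    for i, rid in enumerate(entry.request_ids):
        if i in entry.skip:
            continue
        handled.append(rid)


# Example: the scheduler first drains one victim request, then the normal
# end-of-step drain processes the rest exactly once.
q: Deque[Entry] = deque([Entry(request_ids=["a", "b", "c"], skip=[])])
done: List[str] = []
process_outputs(q, done, request_id="b")  # targeted drain for the victim
process_outputs(q, done)                  # regular drain of the remainder
print(done)                               # ['b', 'a', 'c']
```

This is what allows the scheduler's preemption path to flush just the victim's pending output without disturbing the rest of the queue.
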
@@ -1548,20 +1597,20 @@ class LLMEngine:
             execute_model_req.async_callback = self.async_callbacks[
                 virtual_engine]
 
-            output = self.model_executor.execute_model(
+            outputs = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)
 
             # We need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
-                self._update_cached_scheduler_output(virtual_engine, output)
+                self._update_cached_scheduler_output(virtual_engine, outputs)
         else:
             # Nothing scheduled => If there is pending async postprocessor,
             # then finish it here.
             if len(ctx.output_queue) > 0:
                 self._process_model_outputs(ctx=ctx)
             # No outputs in this case
-            output = []
+            outputs = []
 
         # Finish the current step for all the sequence groups.
         if self.scheduler_config.is_multi_step:
@@ -1574,18 +1623,18 @@ class LLMEngine:
                 self.cached_scheduler_outputs[0] = SchedulerOutputState()
 
             # Add results to the output_queue
-            is_async = allow_async_output_proc
-            is_last_step = True
-            ctx.output_queue.append(
-                (output, seq_group_metadata_list, scheduler_outputs, is_async,
-                 is_last_step))
+            ctx.append_output(outputs=outputs,
+                              seq_group_metadata_list=seq_group_metadata_list,
+                              scheduler_outputs=scheduler_outputs,
+                              is_async=allow_async_output_proc,
+                              is_last_step=True)
 
-            if output and allow_async_output_proc:
-                assert len(output) == 1, (
+            if outputs and allow_async_output_proc:
+                assert len(outputs) == 1, (
                     "Async postprocessor expects only a single output set")
                 self._advance_to_next_step(
-                    output[0], seq_group_metadata_list,
+                    outputs[0], seq_group_metadata_list,
                     scheduler_outputs.scheduled_seq_groups)
 
             # Check if need to run the usual non-async path
@@ -1593,7 +1642,7 @@ class LLMEngine:
                 self._process_model_outputs(ctx=ctx)
 
             # Log stats.
-            self.do_log_stats(scheduler_outputs, output)
+            self.do_log_stats(scheduler_outputs, outputs)
 
             # Tracing
             self.do_tracing(scheduler_outputs)

Changed file 4 of 4

@@ -274,12 +274,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
                     self.pinned_sampled_token_ids)
                 if model_output.pythonized:
                     ctx = output_proc_callback.keywords["ctx"]
-                    is_async = False
-                    is_last_step = False
-                    ctx.output_queue.append(
-                        ([model_output.sampler_output
-                          ], ctx.seq_group_metadata_list,
-                         ctx.scheduler_outputs, is_async, is_last_step))
+                    ctx.append_output(
+                        outputs=[model_output.sampler_output],
+                        seq_group_metadata_list=ctx.seq_group_metadata_list,
+                        scheduler_outputs=ctx.scheduler_outputs,
+                        is_async=False,
+                        is_last_step=False)
 
                     output_proc_callback()
                 else:
                     cont = False
@@ -319,12 +320,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
                         if not is_last_step:
                             ctx = output_proc_callback.keywords[  # type: ignore
                                 "ctx"]  # type: ignore
-                            is_async = False
-                            is_last_step = False
-                            ctx.output_queue.append(
-                                ([output.sampler_output
-                                  ], ctx.seq_group_metadata_list,
-                                 ctx.scheduler_outputs, is_async, is_last_step))
+                            ctx.append_output(
+                                outputs=[output.sampler_output],
+                                seq_group_metadata_list=ctx.
+                                seq_group_metadata_list,
+                                scheduler_outputs=ctx.scheduler_outputs,
+                                is_async=False,
+                                is_last_step=False)
                         else:
                             outputs.append(output.sampler_output)
                     else: