[Bugfix] Fix async postprocessor in case of preemption (#8267)
This commit is contained in:
parent
cfe712bf1a
commit
4ef41b8476
@ -537,13 +537,6 @@ class Scheduler:
|
||||
preempted: List[SequenceGroup] = ret.preempted
|
||||
swapped_out: List[SequenceGroup] = ret.swapped_out
|
||||
|
||||
# NOTE(woosuk): Preemption happens only when there is no available slot
|
||||
# to keep all the sequence groups in the RUNNING state.
|
||||
|
||||
# Store original running requests for the case of async + preemption
|
||||
if self.use_async_output_proc:
|
||||
orig_running = self.running.copy()
|
||||
|
||||
running_queue = self.running
|
||||
assert len(self._async_stopped) == 0
|
||||
while running_queue:
|
||||
@ -552,6 +545,7 @@ class Scheduler:
|
||||
seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
|
||||
|
||||
if num_running_tokens == 0:
|
||||
# No budget => Stop
|
||||
break
|
||||
|
||||
running_queue.popleft()
|
||||
@ -565,18 +559,8 @@ class Scheduler:
|
||||
self._async_stopped.append(seq_group)
|
||||
continue
|
||||
|
||||
# With async postprocessor, when preemption kicks in, we need
|
||||
# first to drain the async postprocessor, so that all async
|
||||
# block_table freeing is applied before the preemption freeing
|
||||
# is applied.
|
||||
if self.use_async_output_proc and not self._can_append_slots(
|
||||
seq_group):
|
||||
tmp = self.running
|
||||
self.running = orig_running
|
||||
assert self.output_proc_callback is not None
|
||||
self.output_proc_callback()
|
||||
self.running = tmp
|
||||
|
||||
# NOTE(woosuk): Preemption happens only when there is no available
|
||||
# slot to keep all the sequence groups in the RUNNING state.
|
||||
while not self._can_append_slots(seq_group):
|
||||
budget.subtract_num_batched_tokens(seq_group.request_id,
|
||||
num_running_tokens)
|
||||
@ -588,24 +572,43 @@ class Scheduler:
|
||||
and seq_group.lora_int_id in curr_loras):
|
||||
curr_loras.remove(seq_group.lora_int_id)
|
||||
|
||||
# Determine victim sequence
|
||||
cont_loop = True
|
||||
if running_queue:
|
||||
# Preempt the lowest-priority sequence groups.
|
||||
# Preempt the lowest-priority sequence group.
|
||||
victim_seq_group = running_queue.pop()
|
||||
else:
|
||||
# No other sequence group can be preempted.
|
||||
# Preempt the current sequence group.
|
||||
# Note: This is also where we stop this loop
|
||||
# (since there is nothing else to preempt)
|
||||
victim_seq_group = seq_group
|
||||
cont_loop = False
|
||||
|
||||
# With async postprocessor, before preempting a sequence
|
||||
# we need to ensure it has no pending async postprocessor
|
||||
do_preempt = True
|
||||
if self.use_async_output_proc:
|
||||
assert self.output_proc_callback is not None
|
||||
self.output_proc_callback(
|
||||
request_id=victim_seq_group.request_id)
|
||||
|
||||
# It may be that the async pending "victim_seq_group"
|
||||
# becomes finished, in which case we simply free it.
|
||||
if victim_seq_group.is_finished():
|
||||
self._free_finished_seq_group(victim_seq_group)
|
||||
do_preempt = False
|
||||
|
||||
# Do preemption
|
||||
if do_preempt:
|
||||
preempted_mode = self._preempt(victim_seq_group,
|
||||
blocks_to_swap_out)
|
||||
if preempted_mode == PreemptionMode.RECOMPUTE:
|
||||
preempted.append(victim_seq_group)
|
||||
else:
|
||||
swapped_out.append(victim_seq_group)
|
||||
else:
|
||||
# No other sequence groups can be preempted.
|
||||
# Preempt the current sequence group.
|
||||
preempted_mode = self._preempt(seq_group,
|
||||
blocks_to_swap_out)
|
||||
if preempted_mode == PreemptionMode.RECOMPUTE:
|
||||
preempted.append(seq_group)
|
||||
else:
|
||||
swapped_out.append(seq_group)
|
||||
|
||||
if not cont_loop:
|
||||
break
|
||||
else:
|
||||
self._append_slots(seq_group, blocks_to_copy)
|
||||
@ -1264,22 +1267,26 @@ class Scheduler:
|
||||
if seq.is_finished():
|
||||
self.free_seq(seq)
|
||||
|
||||
def free_finished_seq_groups(self) -> None:
|
||||
remaining: Deque[SequenceGroup] = deque()
|
||||
for seq_group in self.running:
|
||||
def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None:
|
||||
if seq_group.is_finished():
|
||||
# Free cross-attention block table, if it exists
|
||||
self._free_seq_group_cross_attn_blocks(seq_group)
|
||||
|
||||
# Add the finished requests to the finished requests list.
|
||||
# This list will be used to update the Mamba cache in the
|
||||
# next step.
|
||||
self._finished_requests_ids.append(seq_group.request_id)
|
||||
else:
|
||||
remaining.append(seq_group)
|
||||
|
||||
# Free finished seqs
|
||||
self._free_finished_seqs(seq_group)
|
||||
|
||||
def free_finished_seq_groups(self) -> None:
|
||||
remaining: Deque[SequenceGroup] = deque()
|
||||
for seq_group in self.running:
|
||||
self._free_finished_seq_group(seq_group)
|
||||
if not seq_group.is_finished():
|
||||
remaining.append(seq_group)
|
||||
|
||||
self.running = remaining
|
||||
|
||||
# Handle async stopped sequence groups
|
||||
|
@ -342,17 +342,17 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
virtual_engine]
|
||||
|
||||
# Execute the model.
|
||||
output = await self.model_executor.execute_model_async(
|
||||
outputs = await self.model_executor.execute_model_async(
|
||||
execute_model_req)
|
||||
|
||||
# we need to do this here so that last step's sampled_token_ids can
|
||||
# be passed to the next iteration for PP.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
self._update_cached_scheduler_output(virtual_engine, output)
|
||||
self._update_cached_scheduler_output(virtual_engine, outputs)
|
||||
else:
|
||||
if len(ctx.output_queue) > 0:
|
||||
self._process_model_outputs(ctx=ctx)
|
||||
output = []
|
||||
outputs = []
|
||||
|
||||
# Finish the current step for all the sequence groups.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
@ -365,25 +365,25 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
self.cached_scheduler_outputs[
|
||||
virtual_engine] = SchedulerOutputState()
|
||||
|
||||
is_async = allow_async_output_proc
|
||||
is_last_step = True
|
||||
ctx.output_queue.append(
|
||||
(output, seq_group_metadata_list, scheduler_outputs, is_async,
|
||||
is_last_step))
|
||||
ctx.append_output(outputs=outputs,
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
scheduler_outputs=scheduler_outputs,
|
||||
is_async=allow_async_output_proc,
|
||||
is_last_step=True)
|
||||
|
||||
if output and allow_async_output_proc:
|
||||
if outputs and allow_async_output_proc:
|
||||
assert len(
|
||||
output
|
||||
outputs
|
||||
) == 1, "Async postprocessor expects only a single output set"
|
||||
self._advance_to_next_step(
|
||||
output[0], seq_group_metadata_list,
|
||||
outputs[0], seq_group_metadata_list,
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
|
||||
if not allow_async_output_proc:
|
||||
self._process_model_outputs(ctx=ctx)
|
||||
|
||||
# Log stats.
|
||||
self.do_log_stats(scheduler_outputs, output)
|
||||
self.do_log_stats(scheduler_outputs, outputs)
|
||||
|
||||
# Tracing
|
||||
self.do_tracing(scheduler_outputs)
|
||||
|
@ -2,9 +2,9 @@ import functools
|
||||
import time
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
|
||||
Mapping, Optional)
|
||||
Mapping, NamedTuple, Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set, Tuple, Type, Union
|
||||
|
||||
@ -90,17 +90,36 @@ class SchedulerOutputState:
|
||||
last_output: Optional[SamplerOutput] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputData(NamedTuple):
|
||||
outputs: List[SamplerOutput]
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata]
|
||||
scheduler_outputs: SchedulerOutputs
|
||||
is_async: bool
|
||||
is_last_step: bool
|
||||
skip: List[int]
|
||||
|
||||
|
||||
class SchedulerContext:
|
||||
output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
|
||||
List[SequenceGroupMetadata], SchedulerOutputs,
|
||||
bool,
|
||||
bool]] = field(default_factory=lambda: deque())
|
||||
request_outputs: List[Union[RequestOutput,
|
||||
EmbeddingRequestOutput]] = field(
|
||||
default_factory=lambda: [])
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
|
||||
scheduler_outputs: Optional[SchedulerOutputs] = None
|
||||
|
||||
def __init__(self):
|
||||
self.output_queue: Deque[OutputData] = deque()
|
||||
self.request_outputs: List[Union[RequestOutput,
|
||||
EmbeddingRequestOutput]] = []
|
||||
self.seq_group_metadata_list: Optional[
|
||||
List[SequenceGroupMetadata]] = None
|
||||
self.scheduler_outputs: Optional[SchedulerOutputs] = None
|
||||
|
||||
def append_output(self, outputs: List[SamplerOutput],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
scheduler_outputs: SchedulerOutputs, is_async: bool,
|
||||
is_last_step: bool):
|
||||
self.output_queue.append(
|
||||
OutputData(outputs=outputs,
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
scheduler_outputs=scheduler_outputs,
|
||||
is_async=is_async,
|
||||
is_last_step=is_last_step,
|
||||
skip=[]))
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
@ -1246,23 +1265,15 @@ class LLMEngine:
|
||||
|
||||
return
|
||||
|
||||
def _process_model_outputs(self, ctx: SchedulerContext) -> None:
|
||||
"""Apply the model output to the sequences in the scheduled seq groups.
|
||||
def _process_model_outputs(self,
|
||||
ctx: SchedulerContext,
|
||||
request_id: Optional[str] = None) -> None:
|
||||
"""Apply the model output to the sequences in the scheduled seq groups
|
||||
and return responses.
|
||||
|
||||
virtual_engine: The engine id to operate on
|
||||
ctx: The virtual engine context to work on
|
||||
request_id: If provided, then only this request is going to be processed
|
||||
|
||||
is_async: Indicates whether this postprocessor runs in
|
||||
parallel with the GPU forward pass and is processing
|
||||
tokens from the previous step. If this is true, then
|
||||
no tokens need to be appended since it is already done
|
||||
externally (before the next schedule() call)
|
||||
|
||||
sampler_output: Used with multi-step execution to provide
|
||||
sampler_output of each step
|
||||
is_last_output: Used with multi-step execution to indicate
|
||||
the last step (of each multi-step group)
|
||||
|
||||
Returns RequestOutputs that can be returned to the client.
|
||||
"""
|
||||
now = time.time()
|
||||
|
||||
@ -1270,9 +1281,14 @@ class LLMEngine:
|
||||
return None
|
||||
|
||||
# Get pending async postprocessor
|
||||
if request_id:
|
||||
# When we process only one request, no pop is required
|
||||
# (since later we will process all of the rest)
|
||||
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
|
||||
is_last_step) = ctx.output_queue.popleft()
|
||||
assert outputs is not None
|
||||
is_last_step, skip) = ctx.output_queue[0]
|
||||
else:
|
||||
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
|
||||
is_last_step, skip) = ctx.output_queue.popleft()
|
||||
|
||||
# Sanity check
|
||||
assert len(seq_group_metadata_list) == len(
|
||||
@ -1286,9 +1302,30 @@ class LLMEngine:
|
||||
else:
|
||||
outputs_by_sequence_group = outputs
|
||||
|
||||
# Determine the requests we need to operate on
|
||||
if request_id:
|
||||
indices = []
|
||||
for i, seq_group_meta in enumerate(seq_group_metadata_list):
|
||||
if seq_group_meta.request_id == request_id:
|
||||
assert i not in skip # Cannot be called twice
|
||||
indices.append(i)
|
||||
break
|
||||
|
||||
# If the request_id was not found, then it means that
|
||||
# this is a new request that has no pending async
|
||||
# postprocessor
|
||||
if not indices:
|
||||
return
|
||||
else:
|
||||
indices = range(len(seq_group_metadata_list)) # type: ignore
|
||||
|
||||
finished_before: List[int] = []
|
||||
finished_now: List[int] = []
|
||||
for i, seq_group_meta in enumerate(seq_group_metadata_list):
|
||||
for i in indices:
|
||||
if i in skip:
|
||||
continue
|
||||
|
||||
seq_group_meta = seq_group_metadata_list[i]
|
||||
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
|
||||
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
@ -1343,6 +1380,18 @@ class LLMEngine:
|
||||
request_output = RequestOutputFactory.create(seq_group)
|
||||
ctx.request_outputs.append(request_output)
|
||||
|
||||
# When we process a single request, we skip it for the next time,
|
||||
# and invoke the request output callback (if there was final output)
|
||||
if request_id:
|
||||
assert len(indices) == 1
|
||||
skip.append(indices[0])
|
||||
|
||||
if (finished_now
|
||||
and self.process_request_outputs_callback is not None):
|
||||
self.process_request_outputs_callback(ctx.request_outputs)
|
||||
ctx.request_outputs.clear()
|
||||
return
|
||||
|
||||
# Free currently finished requests
|
||||
if finished_now:
|
||||
for scheduler in self.scheduler:
|
||||
@ -1354,17 +1403,16 @@ class LLMEngine:
|
||||
if (finished_now
|
||||
and self.process_request_outputs_callback is not None):
|
||||
self.process_request_outputs_callback(ctx.request_outputs)
|
||||
ctx.request_outputs.clear()
|
||||
return
|
||||
|
||||
# Create the outputs
|
||||
# Note: scheduled_seq_groups and seq_group_metadata_list
|
||||
# must match with the indices
|
||||
for i, scheduled_seq_group in enumerate(
|
||||
scheduler_outputs.scheduled_seq_groups):
|
||||
|
||||
if i in finished_before or i in finished_now:
|
||||
for i in indices:
|
||||
if i in skip or i in finished_before or i in finished_now:
|
||||
continue # Avoids double processing
|
||||
|
||||
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
|
||||
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
seq_group.maybe_set_first_token_time(now)
|
||||
if (seq_group.is_finished()
|
||||
@ -1380,6 +1428,7 @@ class LLMEngine:
|
||||
if (ctx.request_outputs
|
||||
and self.process_request_outputs_callback is not None):
|
||||
self.process_request_outputs_callback(ctx.request_outputs)
|
||||
ctx.request_outputs.clear()
|
||||
|
||||
# For async case, we need to record the stats here.
|
||||
# For non-async case, the stats are done in the
|
||||
@ -1548,20 +1597,20 @@ class LLMEngine:
|
||||
execute_model_req.async_callback = self.async_callbacks[
|
||||
virtual_engine]
|
||||
|
||||
output = self.model_executor.execute_model(
|
||||
outputs = self.model_executor.execute_model(
|
||||
execute_model_req=execute_model_req)
|
||||
|
||||
# We need to do this here so that last step's sampled_token_ids can
|
||||
# be passed to the next iteration for PP.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
self._update_cached_scheduler_output(virtual_engine, output)
|
||||
self._update_cached_scheduler_output(virtual_engine, outputs)
|
||||
else:
|
||||
# Nothing scheduled => If there is pending async postprocessor,
|
||||
# then finish it here.
|
||||
if len(ctx.output_queue) > 0:
|
||||
self._process_model_outputs(ctx=ctx)
|
||||
# No outputs in this case
|
||||
output = []
|
||||
outputs = []
|
||||
|
||||
# Finish the current step for all the sequence groups.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
@ -1574,18 +1623,18 @@ class LLMEngine:
|
||||
self.cached_scheduler_outputs[0] = SchedulerOutputState()
|
||||
|
||||
# Add results to the output_queue
|
||||
is_async = allow_async_output_proc
|
||||
is_last_step = True
|
||||
ctx.output_queue.append(
|
||||
(output, seq_group_metadata_list, scheduler_outputs, is_async,
|
||||
is_last_step))
|
||||
ctx.append_output(outputs=outputs,
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
scheduler_outputs=scheduler_outputs,
|
||||
is_async=allow_async_output_proc,
|
||||
is_last_step=True)
|
||||
|
||||
if output and allow_async_output_proc:
|
||||
assert len(output) == 1, (
|
||||
if outputs and allow_async_output_proc:
|
||||
assert len(outputs) == 1, (
|
||||
"Async postprocessor expects only a single output set")
|
||||
|
||||
self._advance_to_next_step(
|
||||
output[0], seq_group_metadata_list,
|
||||
outputs[0], seq_group_metadata_list,
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
|
||||
# Check if need to run the usual non-async path
|
||||
@ -1593,7 +1642,7 @@ class LLMEngine:
|
||||
self._process_model_outputs(ctx=ctx)
|
||||
|
||||
# Log stats.
|
||||
self.do_log_stats(scheduler_outputs, output)
|
||||
self.do_log_stats(scheduler_outputs, outputs)
|
||||
|
||||
# Tracing
|
||||
self.do_tracing(scheduler_outputs)
|
||||
|
@ -274,12 +274,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
self.pinned_sampled_token_ids)
|
||||
if model_output.pythonized:
|
||||
ctx = output_proc_callback.keywords["ctx"]
|
||||
is_async = False
|
||||
is_last_step = False
|
||||
ctx.output_queue.append(
|
||||
([model_output.sampler_output
|
||||
], ctx.seq_group_metadata_list,
|
||||
ctx.scheduler_outputs, is_async, is_last_step))
|
||||
ctx.append_output(
|
||||
outputs=[model_output.sampler_output],
|
||||
seq_group_metadata_list=ctx.seq_group_metadata_list,
|
||||
scheduler_outputs=ctx.scheduler_outputs,
|
||||
is_async=False,
|
||||
is_last_step=False)
|
||||
|
||||
output_proc_callback()
|
||||
else:
|
||||
cont = False
|
||||
@ -319,12 +320,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
if not is_last_step:
|
||||
ctx = output_proc_callback.keywords[ # type: ignore
|
||||
"ctx"] # type: ignore
|
||||
is_async = False
|
||||
is_last_step = False
|
||||
ctx.output_queue.append(
|
||||
([output.sampler_output
|
||||
], ctx.seq_group_metadata_list,
|
||||
ctx.scheduler_outputs, is_async, is_last_step))
|
||||
ctx.append_output(
|
||||
outputs=[output.sampler_output],
|
||||
seq_group_metadata_list=ctx.
|
||||
seq_group_metadata_list,
|
||||
scheduler_outputs=ctx.scheduler_outputs,
|
||||
is_async=False,
|
||||
is_last_step=False)
|
||||
else:
|
||||
outputs.append(output.sampler_output)
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user