[Bugfix] Fix async postprocessor in case of preemption (#8267)
commit 4ef41b8476
parent cfe712bf1a
@@ -537,13 +537,6 @@ class Scheduler:
         preempted: List[SequenceGroup] = ret.preempted
         swapped_out: List[SequenceGroup] = ret.swapped_out
 
-        # NOTE(woosuk): Preemption happens only when there is no available slot
-        # to keep all the sequence groups in the RUNNING state.
-
-        # Store original running requests for the case of async + preemption
-        if self.use_async_output_proc:
-            orig_running = self.running.copy()
-
         running_queue = self.running
         assert len(self._async_stopped) == 0
         while running_queue:
@@ -552,6 +545,7 @@ class Scheduler:
                 seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
 
             if num_running_tokens == 0:
+                # No budget => Stop
                 break
 
             running_queue.popleft()
@@ -565,18 +559,8 @@ class Scheduler:
                 self._async_stopped.append(seq_group)
                 continue
 
-            # With async postprocessor, when preemption kicks in, we need
-            # first to drain the async postprocessor, so that all async
-            # block_table freeing is applied before the preemption freeing
-            # is applied.
-            if self.use_async_output_proc and not self._can_append_slots(
-                    seq_group):
-                tmp = self.running
-                self.running = orig_running
-                assert self.output_proc_callback is not None
-                self.output_proc_callback()
-                self.running = tmp
-
+            # NOTE(woosuk): Preemption happens only when there is no available
+            # slot to keep all the sequence groups in the RUNNING state.
             while not self._can_append_slots(seq_group):
                 budget.subtract_num_batched_tokens(seq_group.request_id,
                                                    num_running_tokens)
@@ -588,24 +572,43 @@ class Scheduler:
                         and seq_group.lora_int_id in curr_loras):
                     curr_loras.remove(seq_group.lora_int_id)
 
+                # Determine victim sequence
+                cont_loop = True
                 if running_queue:
-                    # Preempt the lowest-priority sequence groups.
+                    # Preempt the lowest-priority sequence group.
                     victim_seq_group = running_queue.pop()
+                else:
+                    # No other sequence group can be preempted.
+                    # Preempt the current sequence group.
+                    # Note: This is also where we stop this loop
+                    # (since there is nothing else to preempt)
+                    victim_seq_group = seq_group
+                    cont_loop = False
+
+                # With async postprocessor, before preempting a sequence
+                # we need to ensure it has no pending async postprocessor
+                do_preempt = True
+                if self.use_async_output_proc:
+                    assert self.output_proc_callback is not None
+                    self.output_proc_callback(
+                        request_id=victim_seq_group.request_id)
+
+                    # It may be that the async pending "victim_seq_group"
+                    # becomes finished, in which case we simply free it.
+                    if victim_seq_group.is_finished():
+                        self._free_finished_seq_group(victim_seq_group)
+                        do_preempt = False
+
+                # Do preemption
+                if do_preempt:
                     preempted_mode = self._preempt(victim_seq_group,
                                                    blocks_to_swap_out)
                     if preempted_mode == PreemptionMode.RECOMPUTE:
                         preempted.append(victim_seq_group)
                     else:
                         swapped_out.append(victim_seq_group)
-                else:
-                    # No other sequence groups can be preempted.
-                    # Preempt the current sequence group.
-                    preempted_mode = self._preempt(seq_group,
-                                                   blocks_to_swap_out)
-                    if preempted_mode == PreemptionMode.RECOMPUTE:
-                        preempted.append(seq_group)
-                    else:
-                        swapped_out.append(seq_group)
+
+                if not cont_loop:
                     break
             else:
                 self._append_slots(seq_group, blocks_to_copy)
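This hunk is the heart of the fix: instead of restoring a saved copy of self.running and draining the whole async postprocessor (the code removed earlier), the scheduler now drains only the chosen victim via output_proc_callback(request_id=...) before preempting it, and skips preemption entirely when the drain turns out to have finished the request. A minimal sketch of that ordering, using hypothetical stand-in types rather than vLLM's real SequenceGroup/scheduler API:

from collections import deque

# Illustrative stand-in for a running request; not vLLM's API.
class ToyRequest:
    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self.finished = False

    def is_finished(self) -> bool:
        return self.finished


def preempt_one(running: deque, drain_callback, free_fn, preempt_fn):
    # The lowest-priority request sits at the tail of the running queue.
    victim = running.pop()

    # Drain only the victim's pending async output first, so its
    # block-table frees are applied before any preemption frees.
    drain_callback(request_id=victim.request_id)

    if victim.is_finished():
        # Draining may have finished the request: just free it.
        free_fn(victim)
    else:
        preempt_fn(victim)
    return victim


# Usage with no-op callbacks:
q = deque([ToyRequest("a"), ToyRequest("b")])
preempt_one(q, lambda request_id: None, lambda r: None, lambda r: None)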
@@ -1264,22 +1267,26 @@ class Scheduler:
             if seq.is_finished():
                 self.free_seq(seq)
 
+    def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None:
+        if seq_group.is_finished():
+            # Free cross-attention block table, if it exists
+            self._free_seq_group_cross_attn_blocks(seq_group)
+
+            # Add the finished requests to the finished requests list.
+            # This list will be used to update the Mamba cache in the
+            # next step.
+            self._finished_requests_ids.append(seq_group.request_id)
+
+            # Free finished seqs
+            self._free_finished_seqs(seq_group)
+
     def free_finished_seq_groups(self) -> None:
         remaining: Deque[SequenceGroup] = deque()
         for seq_group in self.running:
-            if seq_group.is_finished():
-                # Free cross-attention block table, if it exists
-                self._free_seq_group_cross_attn_blocks(seq_group)
-                # Add the finished requests to the finished requests list.
-                # This list will be used to update the Mamba cache in the
-                # next step.
-                self._finished_requests_ids.append(seq_group.request_id)
-            else:
+            self._free_finished_seq_group(seq_group)
+            if not seq_group.is_finished():
                 remaining.append(seq_group)
 
-            # Free finished seqs
-            self._free_finished_seqs(seq_group)
-
         self.running = remaining
 
         # Handle async stopped sequence groups
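The new _free_finished_seq_group helper exists so the preemption path in _schedule_running can free a victim that the drain discovered to be finished, using the same bookkeeping (cross-attention blocks, the Mamba _finished_requests_ids list, finished seqs) as the regular cleanup loop. A condensed model of the pattern, with illustrative stand-in names:

from typing import List

class Group:
    # Hypothetical stand-in for SequenceGroup.
    def __init__(self, request_id: str, finished: bool = False) -> None:
        self.request_id = request_id
        self._finished = finished

    def is_finished(self) -> bool:
        return self._finished


def free_one_if_finished(group: Group, finished_ids: List[str]) -> None:
    # One helper owns "free a finished group", so both callers agree.
    if group.is_finished():
        finished_ids.append(group.request_id)


def free_finished(groups, finished_ids):
    remaining = []
    for g in groups:
        free_one_if_finished(g, finished_ids)
        if not g.is_finished():
            remaining.append(g)
    return remaining


ids: List[str] = []
live = free_finished([Group("a", finished=True), Group("b")], ids)
assert ids == ["a"] and len(live) == 1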
@@ -342,17 +342,17 @@ class _AsyncLLMEngine(LLMEngine):
                 virtual_engine]
 
             # Execute the model.
-            output = await self.model_executor.execute_model_async(
+            outputs = await self.model_executor.execute_model_async(
                 execute_model_req)
 
             # we need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
-                self._update_cached_scheduler_output(virtual_engine, output)
+                self._update_cached_scheduler_output(virtual_engine, outputs)
         else:
             if len(ctx.output_queue) > 0:
                 self._process_model_outputs(ctx=ctx)
-            output = []
+            outputs = []
 
         # Finish the current step for all the sequence groups.
         if self.scheduler_config.is_multi_step:
@@ -365,25 +365,25 @@ class _AsyncLLMEngine(LLMEngine):
             self.cached_scheduler_outputs[
                 virtual_engine] = SchedulerOutputState()
 
-        is_async = allow_async_output_proc
-        is_last_step = True
-        ctx.output_queue.append(
-            (output, seq_group_metadata_list, scheduler_outputs, is_async,
-             is_last_step))
+        ctx.append_output(outputs=outputs,
+                          seq_group_metadata_list=seq_group_metadata_list,
+                          scheduler_outputs=scheduler_outputs,
+                          is_async=allow_async_output_proc,
+                          is_last_step=True)
 
-        if output and allow_async_output_proc:
+        if outputs and allow_async_output_proc:
             assert len(
-                output
+                outputs
             ) == 1, "Async postprocessor expects only a single output set"
             self._advance_to_next_step(
-                output[0], seq_group_metadata_list,
+                outputs[0], seq_group_metadata_list,
                 scheduler_outputs.scheduled_seq_groups)
 
         if not allow_async_output_proc:
             self._process_model_outputs(ctx=ctx)
 
         # Log stats.
-        self.do_log_stats(scheduler_outputs, output)
+        self.do_log_stats(scheduler_outputs, outputs)
 
         # Tracing
         self.do_tracing(scheduler_outputs)
@@ -2,9 +2,9 @@ import functools
 import time
 from collections import deque
 from contextlib import contextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
-                    Mapping, Optional)
+                    Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Tuple, Type, Union
 
@@ -90,17 +90,36 @@ class SchedulerOutputState:
     last_output: Optional[SamplerOutput] = None
 
 
-@dataclass
+class OutputData(NamedTuple):
+    outputs: List[SamplerOutput]
+    seq_group_metadata_list: List[SequenceGroupMetadata]
+    scheduler_outputs: SchedulerOutputs
+    is_async: bool
+    is_last_step: bool
+    skip: List[int]
+
+
 class SchedulerContext:
-    output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
-                              List[SequenceGroupMetadata], SchedulerOutputs,
-                              bool,
-                              bool]] = field(default_factory=lambda: deque())
-    request_outputs: List[Union[RequestOutput,
-                                EmbeddingRequestOutput]] = field(
-                                    default_factory=lambda: [])
-    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
-    scheduler_outputs: Optional[SchedulerOutputs] = None
+
+    def __init__(self):
+        self.output_queue: Deque[OutputData] = deque()
+        self.request_outputs: List[Union[RequestOutput,
+                                         EmbeddingRequestOutput]] = []
+        self.seq_group_metadata_list: Optional[
+            List[SequenceGroupMetadata]] = None
+        self.scheduler_outputs: Optional[SchedulerOutputs] = None
+
+    def append_output(self, outputs: List[SamplerOutput],
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      scheduler_outputs: SchedulerOutputs, is_async: bool,
+                      is_last_step: bool):
+        self.output_queue.append(
+            OutputData(outputs=outputs,
+                       seq_group_metadata_list=seq_group_metadata_list,
+                       scheduler_outputs=scheduler_outputs,
+                       is_async=is_async,
+                       is_last_step=is_last_step,
+                       skip=[]))
 
 
 class LLMEngine:
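Replacing the ad-hoc 5-tuple with the OutputData NamedTuple also introduces the skip field, which is what lets a per-request drain and the later full drain share one queue entry without double-processing. A minimal sketch of that discipline (Entry and the integer payloads are stand-ins, not vLLM types):

from collections import deque
from typing import Deque, List, NamedTuple

class Entry(NamedTuple):
    outputs: List[int]   # stand-in for List[SamplerOutput]
    is_last_step: bool
    skip: List[int]      # indices already handled by a per-request drain

queue: Deque[Entry] = deque()
queue.append(Entry(outputs=[101, 102], is_last_step=True, skip=[]))

# Per-request drain: peek (queue[0]) and record the handled index...
entry = queue[0]
entry.skip.append(0)

# ...so the later full pass (popleft) skips it. The skip list is a
# mutable field shared between the peek and the pop of the same entry.
full = queue.popleft()
assert full.skip == [0]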
@@ -1246,23 +1265,15 @@ class LLMEngine:
 
             return
 
-    def _process_model_outputs(self, ctx: SchedulerContext) -> None:
-        """Apply the model output to the sequences in the scheduled seq groups.
+    def _process_model_outputs(self,
+                               ctx: SchedulerContext,
+                               request_id: Optional[str] = None) -> None:
+        """Apply the model output to the sequences in the scheduled seq groups
+        and return responses.
 
-        virtual_engine: The engine id to operate on
+        ctx: The virtual engine context to work on
+        request_id: If provided, then only this request is going to be processed
 
-        is_async: Indicates whether this postprocessor runs in
-            parallel with the GPU forward pass and is processing
-            tokens from the previous step. If this is true, then
-            no tokens need to be appended since it is already done
-            externally (before the next schedule() call)
-
-        sampler_output: Used with multi-step execution to provide
-            sampler_output of each step
-        is_last_output: Used with multi-step execution to indicate
-            the last step (of each multi-step group)
-
-        Returns RequestOutputs that can be returned to the client.
         """
         now = time.time()
 
@@ -1270,9 +1281,14 @@ class LLMEngine:
             return None
 
         # Get pending async postprocessor
-        (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
-         is_last_step) = ctx.output_queue.popleft()
-        assert outputs is not None
+        if request_id:
+            # When we process only one request, no pop is required
+            # (since later we will process all of the rest)
+            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
+             is_last_step, skip) = ctx.output_queue[0]
+        else:
+            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
+             is_last_step, skip) = ctx.output_queue.popleft()
 
         # Sanity check
         assert len(seq_group_metadata_list) == len(
@@ -1286,9 +1302,30 @@ class LLMEngine:
         else:
             outputs_by_sequence_group = outputs
 
+        # Determine the requests we need to operate on
+        if request_id:
+            indices = []
+            for i, seq_group_meta in enumerate(seq_group_metadata_list):
+                if seq_group_meta.request_id == request_id:
+                    assert i not in skip  # Cannot be called twice
+                    indices.append(i)
+                    break
+
+            # If the request_id was not found, then it means that
+            # this is a new request that has no pending async
+            # postprocessor
+            if not indices:
+                return
+        else:
+            indices = range(len(seq_group_metadata_list))  # type: ignore
+
         finished_before: List[int] = []
         finished_now: List[int] = []
-        for i, seq_group_meta in enumerate(seq_group_metadata_list):
+        for i in indices:
+            if i in skip:
+                continue
+
+            seq_group_meta = seq_group_metadata_list[i]
             scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
 
             seq_group = scheduled_seq_group.seq_group
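With request_id set, the method operates on at most one index and the skip list ensures the eventual full pass over the same queue entry ignores it. A tiny worked example of the selection logic, using plain dicts in place of SequenceGroupMetadata:

# Hypothetical data illustrating the index selection above.
seq_group_metadata_list = [{"request_id": "a"}, {"request_id": "b"}]
skip: list = []
request_id = "b"

indices = []
for i, meta in enumerate(seq_group_metadata_list):
    if meta["request_id"] == request_id:
        assert i not in skip  # the same request must not be drained twice
        indices.append(i)
        break  # request ids are unique, so stop at the first match

# An empty `indices` would mean a brand-new request with no pending
# async output, in which case the method simply returns.
assert indices == [1]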
@@ -1343,6 +1380,18 @@ class LLMEngine:
                 request_output = RequestOutputFactory.create(seq_group)
                 ctx.request_outputs.append(request_output)
 
+        # When we process a single request, we skip it for the next time,
+        # and invoke the request output callback (if there was final output)
+        if request_id:
+            assert len(indices) == 1
+            skip.append(indices[0])
+
+            if (finished_now
+                    and self.process_request_outputs_callback is not None):
+                self.process_request_outputs_callback(ctx.request_outputs)
+                ctx.request_outputs.clear()
+            return
+
         # Free currently finished requests
         if finished_now:
             for scheduler in self.scheduler:
@@ -1354,17 +1403,16 @@ class LLMEngine:
             if (finished_now
                     and self.process_request_outputs_callback is not None):
                 self.process_request_outputs_callback(ctx.request_outputs)
+                ctx.request_outputs.clear()
             return
 
         # Create the outputs
-        # Note: scheduled_seq_groups and seq_group_metadata_list
-        # must match with the indices
-        for i, scheduled_seq_group in enumerate(
-                scheduler_outputs.scheduled_seq_groups):
-
-            if i in finished_before or i in finished_now:
+        for i in indices:
+            if i in skip or i in finished_before or i in finished_now:
                 continue  # Avoids double processing
 
+            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
+
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             if (seq_group.is_finished()
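Note the ctx.request_outputs.clear() calls added after invoking the callback: since the per-request drain can now fire the callback mid-entry, anything left in ctx.request_outputs would otherwise be delivered a second time on the next invocation. In miniature (the callback is a stand-in):

delivered = []

def callback(outs):
    delivered.extend(outs)

request_outputs = ["out-1"]
callback(request_outputs)
request_outputs.clear()          # new behavior: nothing is re-sent

request_outputs.append("out-2")  # next step's output
callback(request_outputs)
assert delivered == ["out-1", "out-2"]  # no duplicates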
@@ -1380,6 +1428,7 @@ class LLMEngine:
         if (ctx.request_outputs
                 and self.process_request_outputs_callback is not None):
             self.process_request_outputs_callback(ctx.request_outputs)
+            ctx.request_outputs.clear()
 
         # For async case, we need to record the stats here.
         # For non-async case, the stats are done in the
@@ -1548,20 +1597,20 @@ class LLMEngine:
                 execute_model_req.async_callback = self.async_callbacks[
                     virtual_engine]
 
-            output = self.model_executor.execute_model(
+            outputs = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)
 
             # We need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
-                self._update_cached_scheduler_output(virtual_engine, output)
+                self._update_cached_scheduler_output(virtual_engine, outputs)
         else:
             # Nothing scheduled => If there is pending async postprocessor,
             # then finish it here.
             if len(ctx.output_queue) > 0:
                 self._process_model_outputs(ctx=ctx)
             # No outputs in this case
-            output = []
+            outputs = []
 
         # Finish the current step for all the sequence groups.
         if self.scheduler_config.is_multi_step:
@@ -1574,18 +1623,18 @@ class LLMEngine:
             self.cached_scheduler_outputs[0] = SchedulerOutputState()
 
         # Add results to the output_queue
-        is_async = allow_async_output_proc
-        is_last_step = True
-        ctx.output_queue.append(
-            (output, seq_group_metadata_list, scheduler_outputs, is_async,
-             is_last_step))
+        ctx.append_output(outputs=outputs,
+                          seq_group_metadata_list=seq_group_metadata_list,
+                          scheduler_outputs=scheduler_outputs,
+                          is_async=allow_async_output_proc,
+                          is_last_step=True)
 
-        if output and allow_async_output_proc:
-            assert len(output) == 1, (
+        if outputs and allow_async_output_proc:
+            assert len(outputs) == 1, (
                 "Async postprocessor expects only a single output set")
 
             self._advance_to_next_step(
-                output[0], seq_group_metadata_list,
+                outputs[0], seq_group_metadata_list,
                 scheduler_outputs.scheduled_seq_groups)
 
         # Check if need to run the usual non-async path
@@ -1593,7 +1642,7 @@ class LLMEngine:
             self._process_model_outputs(ctx=ctx)
 
         # Log stats.
-        self.do_log_stats(scheduler_outputs, output)
+        self.do_log_stats(scheduler_outputs, outputs)
 
         # Tracing
         self.do_tracing(scheduler_outputs)
@@ -274,12 +274,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
                                              self.pinned_sampled_token_ids)
                 if model_output.pythonized:
                     ctx = output_proc_callback.keywords["ctx"]
-                    is_async = False
-                    is_last_step = False
-                    ctx.output_queue.append(
-                        ([model_output.sampler_output
-                          ], ctx.seq_group_metadata_list,
-                         ctx.scheduler_outputs, is_async, is_last_step))
+                    ctx.append_output(
+                        outputs=[model_output.sampler_output],
+                        seq_group_metadata_list=ctx.seq_group_metadata_list,
+                        scheduler_outputs=ctx.scheduler_outputs,
+                        is_async=False,
+                        is_last_step=False)
+
                     output_proc_callback()
                 else:
                     cont = False
@@ -319,12 +320,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
                     if not is_last_step:
                         ctx = output_proc_callback.keywords[  # type: ignore
                             "ctx"]  # type: ignore
-                        is_async = False
-                        is_last_step = False
-                        ctx.output_queue.append(
-                            ([output.sampler_output
-                              ], ctx.seq_group_metadata_list,
-                             ctx.scheduler_outputs, is_async, is_last_step))
+                        ctx.append_output(
+                            outputs=[output.sampler_output],
+                            seq_group_metadata_list=ctx.
+                            seq_group_metadata_list,
+                            scheduler_outputs=ctx.scheduler_outputs,
+                            is_async=False,
+                            is_last_step=False)
                     else:
                         outputs.append(output.sampler_output)
                 else: