[Core] Combine async postprocessor and multi-step (#7921)
parent f205c09854
commit 3f60f2244e
@@ -47,10 +47,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
 @pytest.mark.parametrize("eager_mode", [False, True])
 @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
+@pytest.mark.parametrize("is_async", [False, True])
 @pytest.mark.asyncio
 async def test_multi_step(example_prompts, model: str, tp_size: int,
                           pp_size: int, eager_mode: int,
-                          num_scheduler_steps: int, num_prompts: int):
+                          num_scheduler_steps: int, num_prompts: int,
+                          is_async: bool):

     prompts = example_prompts
     if len(prompts) < num_prompts:
@@ -62,9 +64,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
     ms_server_args = DEFAULT_SERVER_ARGS + \
         ["--num-scheduler-steps", f"{num_scheduler_steps}"]

-    # Disable output proc callback as its not supported
-    # with multi-step right now
-    ms_server_args += ["--disable-async-output-proc"]
+    if not is_async:
+        ms_server_args += ["--disable-async-output-proc"]

     if eager_mode:
         ms_server_args.append("--enforce-eager")
@@ -1107,10 +1107,7 @@ class Scheduler:
         if not self.cache_config.enable_prefix_caching:
             common_computed_block_nums = []

-        # TODO: Combine multi-step and async postprocessor
-        allow_async_output_proc: bool = (
-            self.use_async_output_proc
-            and not self.scheduler_config.is_multi_step)
+        allow_async_output_proc: bool = self.use_async_output_proc

         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
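With this scheduler change, `allow_async_output_proc` is reported independently of multi-step, and the engine below derives a combined mode from the two flags. A minimal sketch of that decision, using illustrative names rather than vLLM's actual classes:

# Illustrative sketch only: how the two flags combine after this change.
from dataclasses import dataclass


@dataclass
class SchedulerConfig:
    is_multi_step: bool


def detect_modes(use_async_output_proc: bool, cfg: SchedulerConfig):
    # Before this commit the scheduler forced the flag off for multi-step;
    # now it passes it through and the engine combines the two.
    allow_async_output_proc = use_async_output_proc
    use_async_and_multi_step = cfg.is_multi_step and allow_async_output_proc
    return allow_async_output_proc, use_async_and_multi_step


print(detect_modes(True, SchedulerConfig(is_multi_step=True)))   # (True, True)
print(detect_modes(True, SchedulerConfig(is_multi_step=False)))  # (True, False)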
@@ -279,6 +279,10 @@ class _AsyncLLMEngine(LLMEngine):
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

+        # Detect async + multi-step
+        use_async_and_multi_step = (self.scheduler_config.is_multi_step
+                                    and allow_async_output_proc)
+
         ctx = self.scheduler_contexts[virtual_engine]

         # skip the scheduler if there are any remaining steps in the seq groups.
@@ -289,17 +293,27 @@ class _AsyncLLMEngine(LLMEngine):
             # Clear outputs on scheduler iteration start
             ctx.request_outputs.clear()

             # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
              ) = self.scheduler[virtual_engine].schedule()

-            # If current scheduler iteration has no async postprocessor,
-            # then we need first to drain the pending async postprocessor
-            # before moving forward
+            # Detect async + multi-step
+            use_async_and_multi_step = (self.scheduler_config.is_multi_step
+                                        and allow_async_output_proc)
+
+            # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
                 self._process_model_outputs(virtual_engine=virtual_engine,
                                             is_async=True)

+            # For async + multi-step, init the queue
+            if use_async_and_multi_step:
+                assert len(ctx.output_queue) == 0
+                assert seq_group_metadata_list is not None
+                ctx.output_queue.append(
+                    (None, seq_group_metadata_list, scheduler_outputs))
+
         if (self.scheduler_config.is_multi_step
                 and scheduler_outputs.num_lookahead_slots > 0):
             # cache the scheduler outputs for the next iteration if we have
@@ -311,9 +325,6 @@ class _AsyncLLMEngine(LLMEngine):
         assert seq_group_metadata_list is not None
         assert scheduler_outputs is not None

-        assert not (self.scheduler_config.is_multi_step and \
-                    allow_async_output_proc)
-
         if not scheduler_outputs.is_empty():
             finished_requests_ids = self.scheduler[
                 virtual_engine].get_and_reset_finished_requests_ids()
@@ -339,8 +350,13 @@ class _AsyncLLMEngine(LLMEngine):
                 last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                execute_model_req.async_callback = self.async_callback[
-                    virtual_engine]
+                async_callback = self.async_callback_multi_step[
+                    virtual_engine] if use_async_and_multi_step \
+                        else self.async_callback[virtual_engine]
+
+                execute_model_req.async_callback = async_callback
+                execute_model_req.use_async_and_multi_step = \
+                    use_async_and_multi_step

             # Execute the model.
             output = await self.model_executor.execute_model_async(
@@ -350,7 +366,7 @@ class _AsyncLLMEngine(LLMEngine):
         if self.scheduler_config.is_multi_step:
             self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if len(ctx.output_queue) > 0:
+            if not use_async_and_multi_step and len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
                 self._process_model_outputs(virtual_engine=virtual_engine,
                                             is_async=True)
@@ -362,22 +378,25 @@ class _AsyncLLMEngine(LLMEngine):
                 seq_group.finish_step()

         if not self._has_remaining_steps(seq_group_metadata_list):
-            # clear the cache if we have finished all the steps
+            # Clear the cache if we have finished all the steps
             if self.scheduler_config.is_multi_step:
                 self.cached_scheduler_outputs[
                     virtual_engine] = SchedulerOutputState()

-            # Cache results in engine
-            ctx.output_queue.append(
-                (output, seq_group_metadata_list, scheduler_outputs))
+            if use_async_and_multi_step:
+                # For async + multi-step, clear the queue
+                ctx.output_queue.clear()
+            else:
+                ctx.output_queue.append(
+                    (output, seq_group_metadata_list, scheduler_outputs))

-            if output and allow_async_output_proc:
-                assert len(
-                    output
-                ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
-                self._advance_to_next_step(
-                    output[0], seq_group_metadata_list,
-                    scheduler_outputs.scheduled_seq_groups)
+                if output and allow_async_output_proc:
+                    assert len(
+                        output
+                    ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
+                    self._advance_to_next_step(
+                        output[0], seq_group_metadata_list,
+                        scheduler_outputs.scheduled_seq_groups)

             if not allow_async_output_proc:
                 self._process_model_outputs(virtual_engine=virtual_engine,
@@ -390,7 +409,11 @@ class _AsyncLLMEngine(LLMEngine):
             self.do_tracing(scheduler_outputs)

         else:
-            ctx.request_outputs = []
+            # Multi-step case
+            if use_async_and_multi_step:
+                return []
+            else:
+                ctx.request_outputs = []

         if not self.has_unfinished_requests():
             # Drain async postprocessor (if exists)
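In `step_async`, the combined mode is handled by seeding the context's output queue with a `(None, seq_group_metadata_list, scheduler_outputs)` placeholder before execution and by routing the worker callback to `async_callback_multi_step`, which drives `_process_model_outputs` once per decode step. A stripped-down sketch of that queue/callback handshake (toy objects, not vLLM's API):

# Toy sketch of the queue/callback handshake; names are illustrative only.
from collections import deque
from functools import partial

output_queue = deque()          # holds (outputs, metadata, sched_outputs)
processed = []                  # stands in for ctx.request_outputs


def process_model_outputs(is_async, sampler_output=None, is_last_output=False):
    # Multi-step path: peek at the placeholder, pair it with this step's output.
    outputs, metadata, sched = output_queue[0]
    assert outputs is None
    processed.append((sampler_output, metadata, is_last_output))


def step_async(metadata, sched, num_steps=3):
    # Seed the queue once per scheduler iteration (async + multi-step case).
    output_queue.append((None, metadata, sched))
    callback = partial(process_model_outputs, is_async=False)
    for step in range(num_steps):          # worker runs each decode step
        callback(sampler_output=f"sampler_output_{step}",
                 is_last_output=(step == num_steps - 1))
    output_queue.clear()                   # engine clears it after the run


step_async(metadata=["seq_group_0"], sched="scheduler_outputs")
print(processed)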
@@ -91,7 +91,8 @@ class SchedulerOutputState:

 @dataclass
 class SchedulerContext:
-    output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata],
+    output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
+                              List[SequenceGroupMetadata],
                               SchedulerOutputs]] = field(
                                   default_factory=lambda: deque())

@@ -432,6 +433,13 @@ class LLMEngine:
             for v_id in range(self.parallel_config.pipeline_parallel_size)
         ]

+        self.async_callback_multi_step = [
+            functools.partial(self._process_model_outputs,
+                              virtual_engine=v_id,
+                              is_async=False)
+            for v_id in range(self.parallel_config.pipeline_parallel_size)
+        ]
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).

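The per-virtual-engine callbacks are pre-bound versions of `_process_model_outputs`: the existing async callback binds `is_async=True`, while the new multi-step callback binds `is_async=False` and leaves `sampler_output` and `is_last_output` to be supplied on each step. A minimal, standalone illustration of this `functools.partial` pattern (illustrative names, not the real engine):

# Standalone illustration of the pre-bound callback pattern; not vLLM code.
import functools


class Engine:
    def __init__(self, pipeline_parallel_size: int):
        self.async_callback = [
            functools.partial(self._process_model_outputs,
                              virtual_engine=v_id,
                              is_async=True)
            for v_id in range(pipeline_parallel_size)
        ]
        self.async_callback_multi_step = [
            functools.partial(self._process_model_outputs,
                              virtual_engine=v_id,
                              is_async=False)
            for v_id in range(pipeline_parallel_size)
        ]

    def _process_model_outputs(self, virtual_engine, is_async,
                               sampler_output=None, is_last_output=False):
        print(virtual_engine, is_async, sampler_output, is_last_output)


engine = Engine(pipeline_parallel_size=2)
# The multi-step runner only needs to pass the per-step arguments:
engine.async_callback_multi_step[0](sampler_output="step_0", is_last_output=False)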
@@ -1240,28 +1248,49 @@ class LLMEngine:

         return

-    def _process_model_outputs(self, virtual_engine: int,
-                               is_async: bool) -> None:
+    def _process_model_outputs(self,
+                               virtual_engine: int,
+                               is_async: bool,
+                               sampler_output: Optional[SamplerOutput] = None,
+                               is_last_output: bool = False) -> None:
         """Apply the model output to the sequences in the scheduled seq groups.

         virtual_engine: The engine id to operate on

         is_async: Indicates whether this postprocessor runs in
             parallel with the GPU forward pass and is processing
             tokens from the previous step. If this is true, then
             no tokens need to be appended since it is already done
             externally (before the next schedule() call)

+        sampler_output: Used with multi-step execution to provide
+            sampler_output of each step
+        is_last_output: Used with multi-step execution to indicate
+            the last step (of each multi-step group)
+
         Returns RequestOutputs that can be returned to the client.
         """
         now = time.time()

+        is_multi_step = sampler_output is not None
+
         ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]

         if len(ctx.output_queue) == 0:
             return None

-        (outputs, seq_group_metadata_list,
-         scheduler_outputs) = ctx.output_queue.popleft()
+        if is_multi_step:
+            # Async + multi-step case
+            (outputs, seq_group_metadata_list,
+             scheduler_outputs) = ctx.output_queue[0]
+            assert outputs is None
+            outputs = [sampler_output]
+        else:
+            # Async standard case
+            (outputs, seq_group_metadata_list,
+             scheduler_outputs) = ctx.output_queue.popleft()
+
+            assert outputs is not None

         # Sanity check
         assert len(seq_group_metadata_list) == len(
@@ -1320,7 +1349,11 @@ class LLMEngine:
             self.output_processor.process_outputs(seq_group, output,
                                                   is_async)

-        # Free the finished sequence groups.
+        # For async + multi-step, free finished seqs and create outputs
+        # only on the final step.
+        if is_multi_step and not is_last_output:
+            return
+
         for scheduler in self.scheduler:
             scheduler.free_finished_seq_groups()

@@ -1328,7 +1361,7 @@ class LLMEngine:
         for i, _ in enumerate(seq_group_metadata_list):
             scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

-            if i in finished_before:
+            if not is_multi_step and i in finished_before:
                 continue  # Avoids double processing

             seq_group = scheduled_seq_group.seq_group
@@ -1342,7 +1375,11 @@ class LLMEngine:
             request_output = RequestOutputFactory.create(seq_group)
             ctx.request_outputs.append(request_output)

-        if is_async:
+        # For async + multi-step, do stats only on the last output.
+        # Otherwise, do stats if the execution is async
+        do_stats = is_multi_step or is_async
+
+        if do_stats:
             # Log stats.
             self.do_log_stats(scheduler_outputs, outputs, finished_before)

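The reworked `_process_model_outputs` therefore behaves as a per-step hook in the combined mode: each call appends that step's tokens, but freeing finished sequences, building request outputs and logging stats are deferred until the call that carries `is_last_output=True`. A simplified sketch of that gating (illustrative only, not the engine's real bookkeeping):

# Simplified sketch of the is_last_output gating; illustrative only.
def process_step(outputs_so_far, sampler_output, is_last_output):
    # Per-step work: record the new sampler output for the scheduled sequences.
    outputs_so_far.append(sampler_output)

    # For async + multi-step, defer the expensive bookkeeping
    # (freeing finished sequences, building RequestOutputs, stats)
    # until the final step of the group.
    if not is_last_output:
        return None

    return {"request_outputs": list(outputs_so_far), "stats_logged": True}


steps = ["tok_0", "tok_1", "tok_2"]
acc = []
results = [process_step(acc, s, i == len(steps) - 1) for i, s in enumerate(steps)]
print(results)  # [None, None, {...}] -> only the last step produces outputs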
@@ -1437,7 +1474,7 @@
             "as performance will be severely degraded otherwise.")

         # For llm_engine, there is no pipeline parallel support, so the engine
-        # used is always 0
+        # used is always 0.
         virtual_engine = 0

         # These are cached outputs from previous iterations. None if on first
@@ -1447,6 +1484,10 @@
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

+        # Detect async + multi-step
+        use_async_and_multi_step = (self.scheduler_config.is_multi_step
+                                    and allow_async_output_proc)
+
         ctx = self.scheduler_contexts[virtual_engine]

         # Skip the scheduler if there are any remaining steps in the seq groups.
@@ -1462,11 +1503,22 @@
              allow_async_output_proc
              ) = self.scheduler[virtual_engine].schedule()

+            # Detect async + multi-step
+            use_async_and_multi_step = (self.scheduler_config.is_multi_step
+                                        and allow_async_output_proc)
+
             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
                 self._process_model_outputs(virtual_engine=virtual_engine,
                                             is_async=True)

+            # For async + multi-step, init the queue
+            if use_async_and_multi_step:
+                assert len(ctx.output_queue) == 0
+                assert seq_group_metadata_list is not None
+                ctx.output_queue.append(
+                    (None, seq_group_metadata_list, scheduler_outputs))
+
         if (self.scheduler_config.is_multi_step
                 and scheduler_outputs.num_lookahead_slots > 0):
             # cache the scheduler outputs for the next iteration if we have
@@ -1478,9 +1530,6 @@
         assert seq_group_metadata_list is not None
         assert scheduler_outputs is not None

-        assert not (self.scheduler_config.is_multi_step and \
-                    allow_async_output_proc)
-
         if not scheduler_outputs.is_empty():
             finished_requests_ids = self.scheduler[
                 virtual_engine].get_and_reset_finished_requests_ids()
@@ -1505,8 +1554,13 @@
                 last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                execute_model_req.async_callback = self.async_callback[
-                    virtual_engine]
+                async_callback = self.async_callback_multi_step[
+                    virtual_engine] if use_async_and_multi_step \
+                        else self.async_callback[virtual_engine]
+
+                execute_model_req.async_callback = async_callback
+                execute_model_req.use_async_and_multi_step = \
+                    use_async_and_multi_step

             output = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)
@@ -1518,7 +1572,7 @@
         else:
             # Nothing scheduled => If there is pending async postprocessor,
             # then finish it here.
-            if len(ctx.output_queue) > 0:
+            if not use_async_and_multi_step and len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
                 self._process_model_outputs(virtual_engine=virtual_engine,
                                             is_async=True)
@@ -1535,18 +1589,23 @@
             if self.scheduler_config.is_multi_step:
                 self.cached_scheduler_outputs[0] = SchedulerOutputState()

-            # Add results to the output_queue
-            # (for async or non-async postprocessing)
-            ctx.output_queue.append(
-                (output, seq_group_metadata_list, scheduler_outputs))
+            if use_async_and_multi_step:
+                # For async + multi-step, clear the queue
+                ctx.output_queue.clear()
+            else:
+                # Add results to the output_queue
+                # (for async or non-async postprocessing)
+                ctx.output_queue.append(
+                    (output, seq_group_metadata_list, scheduler_outputs))

-            if output and allow_async_output_proc:
-                assert len(output) == 1, ("Multi step decoding does not work "
-                                          "with async output processing.")
+                if output and allow_async_output_proc:
+                    assert len(output) == 1, (
+                        "Multi step decoding does not work "
+                        "with async output processing.")

-                self._advance_to_next_step(
-                    output[0], seq_group_metadata_list,
-                    scheduler_outputs.scheduled_seq_groups)
+                    self._advance_to_next_step(
+                        output[0], seq_group_metadata_list,
+                        scheduler_outputs.scheduled_seq_groups)

             # Check if need to run the usual non-async path
             if not allow_async_output_proc:
@@ -1560,7 +1619,10 @@
             self.do_tracing(scheduler_outputs)
         else:
             # Multi-step case
-            ctx.request_outputs = []
+            if use_async_and_multi_step:
+                return []
+            else:
+                ctx.request_outputs = []

         if not self.has_unfinished_requests():
             # Drain async postprocessor (if exists)
@@ -1295,6 +1295,7 @@ class ExecuteModelRequest(
     last_sampled_token_ids: Optional[torch.Tensor] = None
     # Async callback
     async_callback: Optional[Callable] = None
+    use_async_and_multi_step: bool = False

     @property
     def is_first_multi_step(self) -> bool:
@@ -1341,4 +1342,5 @@ class ExecuteModelRequest(
             finished_requests_ids=self.finished_requests_ids,
             last_sampled_token_ids=self.last_sampled_token_ids.clone()
             if self.last_sampled_token_ids is not None else None,
-            async_callback=self.async_callback)
+            async_callback=self.async_callback,
+            use_async_and_multi_step=self.use_async_and_multi_step)
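The new flag rides along with the existing `async_callback` field: the engine sets it on the `ExecuteModelRequest`, it survives `clone()`, and the multi-step worker later copies it into the frozen model input. A minimal sketch of that plumbing pattern with plain dataclasses (illustrative, not the real classes):

# Illustrative plumbing sketch with plain dataclasses; not vLLM's classes.
import dataclasses
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Request:
    async_callback: Optional[Callable] = None
    use_async_and_multi_step: bool = False

    def clone(self) -> "Request":
        # New fields must be threaded through clone() explicitly.
        return Request(async_callback=self.async_callback,
                       use_async_and_multi_step=self.use_async_and_multi_step)


@dataclass
class FrozenModelInput:
    async_callback: Optional[Callable] = None
    use_async_and_multi_step: bool = False


req = Request(async_callback=print, use_async_and_multi_step=True).clone()
frozen = dataclasses.replace(FrozenModelInput(),
                             async_callback=req.async_callback,
                             use_async_and_multi_step=req.use_async_and_multi_step)
print(frozen.use_async_and_multi_step)  # True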
@@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
     finished_requests_ids: Optional[List[str]] = None
     virtual_engine: int = 0
     async_callback: Optional[Callable] = None
+    use_async_and_multi_step: bool = False

     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -1,5 +1,7 @@
 import dataclasses
+import functools
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

 try:
     from vllm.attention.backends.flash_attn import FlashAttentionMetadata
@@ -215,6 +217,46 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
         )
         return model_input

+    def _async_process_outputs(self, model_input: StatefulModelInput,
+                               output_proc_callback: Callable):
+        # Proceed with pythonization and output_proc in order.
+        # Stop on the first one that fails to pythonize
+        cont = True
+        for model_output in model_input.cached_outputs:
+            if not model_output.pythonized:
+                model_output.maybe_pythonize(model_input, self._copy_stream,
+                                             self.pinned_sampled_token_ids)
+                if model_output.pythonized:
+                    output_proc_callback(
+                        sampler_output=model_output.sampler_output)
+                else:
+                    cont = False
+
+            if not cont:
+                break
+
+    def _final_process_outputs(self, model_input: StatefulModelInput,
+                               output_proc_callback: Optional[Callable]):
+        assert model_input.frozen_model_input is not None
+
+        outputs = []
+        for output_id in range(len(model_input.cached_outputs)):
+            is_last_output = output_id == len(model_input.cached_outputs) - 1
+
+            output = model_input.cached_outputs[output_id]
+            if not output.pythonized:
+                output.pythonize(model_input, self._copy_stream,
+                                 self.pinned_sampled_token_ids)
+
+                if model_input.frozen_model_input.use_async_and_multi_step:
+                    assert output_proc_callback is not None
+                    output_proc_callback(sampler_output=output.sampler_output,
+                                         is_last_output=is_last_output)
+
+            outputs.append(output.sampler_output)
+
+        return outputs
+
     @torch.inference_mode()
     def execute_model(
         self,
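`_async_process_outputs` is what the per-step callback path runs between steps: it walks the cached step outputs in order, pythonizes any that are ready, hands each one to the engine callback, and stops at the first output whose GPU work has not landed yet. A self-contained sketch of that in-order, stop-on-first-pending drain (toy objects, not the real runner):

# Toy sketch of draining cached step outputs in order; illustrative only.
from dataclasses import dataclass
from typing import List


@dataclass
class CachedOutput:
    ready: bool                 # stands in for "GPU result has landed"
    sampler_output: str
    pythonized: bool = False

    def maybe_pythonize(self):
        # Only convert to Python objects if the result is already available.
        if self.ready:
            self.pythonized = True


def async_process_outputs(cached: List[CachedOutput], callback):
    # Process in order and stop at the first output that is not ready,
    # so the engine always sees step results in sequence.
    for out in cached:
        if not out.pythonized:
            out.maybe_pythonize()
            if not out.pythonized:
                break
            callback(sampler_output=out.sampler_output)


cached = [CachedOutput(True, "s0"), CachedOutput(True, "s1"),
          CachedOutput(False, "s2")]
async_process_outputs(cached,
                      callback=lambda sampler_output: print(sampler_output))
# prints s0 and s1, leaving s2 for a later drain or the final flush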
@@ -271,6 +313,20 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
             model_input = self._advance_step(
                 model_input, model_input.cached_outputs[-1].sampler_output)

+            output_proc_callback = None
+            if frozen_model_input.use_async_and_multi_step:
+                output_proc_callback = frozen_model_input.async_callback
+                assert output_proc_callback is not None
+                async_callback = functools.partial(
+                    self._async_process_outputs,
+                    model_input=model_input,
+                    output_proc_callback=output_proc_callback)
+
+                frozen_model_input = dataclasses.replace(  # type: ignore
+                    model_input.frozen_model_input,
+                    async_callback=async_callback)
+                assert frozen_model_input is not None
+
         # Execute the model
         output = self._base_model_runner.execute_model(frozen_model_input,
                                                        kv_caches,
@@ -301,9 +357,11 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
                 output[0].logprobs = None
             # Pythonize the output if CPU is ahead and the previous step is
             # ready.
-            for model_output in model_input.cached_outputs:
-                model_output.maybe_pythonize(model_input, self._copy_stream,
-                                             self.pinned_sampled_token_ids)
+            if not frozen_model_input.use_async_and_multi_step:
+                for model_output in model_input.cached_outputs:
+                    model_output.maybe_pythonize(model_input,
+                                                 self._copy_stream,
+                                                 self.pinned_sampled_token_ids)

         model_input.current_step += 1

@@ -316,11 +374,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):

         # Pythonize the output and block if needed since it is the last step
         if model_input.is_last_step:
-            outputs = []
-            for output in model_input.cached_outputs:
-                output.pythonize(model_input, self._copy_stream,
-                                 self.pinned_sampled_token_ids)
-                outputs.append(output.sampler_output)
+            outputs = self._final_process_outputs(model_input,
+                                                  output_proc_callback)
            return outputs

         # should be [SamplerOutput]
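Inside `execute_model`, the runner rebinds the request's engine callback with `functools.partial` so that the base model runner can fire per-step processing without knowing anything about multi-step state, and the last step flushes whatever is still pending via `_final_process_outputs`. A sketch of that wrapping pattern under the same assumptions (stand-in functions, not the real runner or engine):

# Illustrative sketch of wrapping an engine callback for per-step processing.
import functools


def async_process_outputs(model_input, output_proc_callback):
    # Stand-in for the runner's per-step drain: hand each ready result
    # to the engine callback.
    for step_output in model_input["ready_outputs"]:
        output_proc_callback(sampler_output=step_output)


def engine_callback(sampler_output=None, is_last_output=False):
    # Stand-in for the engine's pre-bound _process_model_outputs partial.
    print("engine processed", sampler_output, "last:", is_last_output)


model_input = {"ready_outputs": ["s0", "s1"]}

# The runner rebinds the request's callback so the *base* model runner can
# fire it after each forward pass without knowing about multi-step state.
async_callback = functools.partial(async_process_outputs,
                                   model_input=model_input,
                                   output_proc_callback=engine_callback)

async_callback()                                            # per-step drain
engine_callback(sampler_output="s2", is_last_output=True)   # final-step flush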
@@ -1,3 +1,4 @@
+import dataclasses
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple

@@ -61,6 +62,13 @@ class MultiStepWorker(Worker):
                 execute_model_req.seq_group_metadata_list,
                 execute_model_req.virtual_engine,
                 execute_model_req.finished_requests_ids))
+
+            if execute_model_req.async_callback:
+                model_input.frozen_model_input = dataclasses.replace(  # type: ignore
+                    model_input.frozen_model_input,
+                    async_callback=execute_model_req.async_callback,
+                    use_async_and_multi_step=execute_model_req.
+                    use_async_and_multi_step)
         else:
             # on subsequent steps we reuse the worker input and model input
             multi_step_state = self.multi_step_states[virtual_engine]