[Core] Combine async postprocessor and multi-step (#7921)
commit 3f60f2244e
parent f205c09854
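At a high level, this change lets the engine's async output postprocessor run together with multi-step decoding: the multi-step model runner receives the postprocessor as a callback and invokes it once per decode step, flagging the final step. Below is a minimal, self-contained sketch of that callback-binding pattern, with illustrative names only (in the diff the bound method is LLMEngine._process_model_outputs and the pre-bound list is self.async_callback_multi_step):

import functools
from typing import Optional

# Illustrative stand-in for the engine's output postprocessor; the real method
# additionally reads pending entries from the scheduler context's output_queue.
def process_model_outputs(virtual_engine: int,
                          is_async: bool,
                          sampler_output: Optional[str] = None,
                          is_last_output: bool = False) -> None:
    step_kind = "final" if is_last_output else "intermediate"
    print(f"engine {virtual_engine}: async={is_async}, "
          f"{step_kind} step output={sampler_output!r}")

# One pre-bound callback per virtual (pipeline-parallel) engine, mirroring
# the async_callback_multi_step list added in the diff.
pipeline_parallel_size = 2
async_callback_multi_step = [
    functools.partial(process_model_outputs, virtual_engine=v_id, is_async=False)
    for v_id in range(pipeline_parallel_size)
]

# The multi-step runner then invokes the callback once per decode step,
# passing that step's sampler output and flagging the last step.
steps = ["token_step_0", "token_step_1", "token_step_2"]
for i, sampler_output in enumerate(steps):
    async_callback_multi_step[0](sampler_output=sampler_output,
                                 is_last_output=(i == len(steps) - 1))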
@@ -47,10 +47,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
 @pytest.mark.parametrize("eager_mode", [False, True])
 @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
+@pytest.mark.parametrize("is_async", [False, True])
 @pytest.mark.asyncio
 async def test_multi_step(example_prompts, model: str, tp_size: int,
                           pp_size: int, eager_mode: int,
-                          num_scheduler_steps: int, num_prompts: int):
+                          num_scheduler_steps: int, num_prompts: int,
+                          is_async: bool):
 
     prompts = example_prompts
     if len(prompts) < num_prompts:

@@ -62,9 +64,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
     ms_server_args = DEFAULT_SERVER_ARGS + \
         ["--num-scheduler-steps", f"{num_scheduler_steps}"]
 
-    # Disable output proc callback as its not supported
-    # with multi-step right now
-    ms_server_args += ["--disable-async-output-proc"]
+    if not is_async:
+        ms_server_args += ["--disable-async-output-proc"]
+
     if eager_mode:
         ms_server_args.append("--enforce-eager")
 

@@ -1107,10 +1107,7 @@ class Scheduler:
         if not self.cache_config.enable_prefix_caching:
             common_computed_block_nums = []
 
-        # TODO: Combine multi-step and async postprocessor
-        allow_async_output_proc: bool = (
-            self.use_async_output_proc
-            and not self.scheduler_config.is_multi_step)
+        allow_async_output_proc: bool = self.use_async_output_proc
 
         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []

@@ -279,6 +279,10 @@ class _AsyncLLMEngine(LLMEngine):
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc
 
+        # Detect async + multi-step
+        use_async_and_multi_step = (self.scheduler_config.is_multi_step
+                                    and allow_async_output_proc)
+
         ctx = self.scheduler_contexts[virtual_engine]
 
         # skip the scheduler if there are any remaining steps in the seq groups.

@@ -289,17 +293,27 @@ class _AsyncLLMEngine(LLMEngine):
             # Clear outputs on scheduler iteration start
             ctx.request_outputs.clear()
 
+            # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
              ) = self.scheduler[virtual_engine].schedule()
 
-            # If current scheduler iteration has no async postprocessor,
-            # then we need first to drain the pending async postprocessor
-            # before moving forward
+            # Detect async + multi-step
+            use_async_and_multi_step = (self.scheduler_config.is_multi_step
+                                        and allow_async_output_proc)
+
+            # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
                 self._process_model_outputs(virtual_engine=virtual_engine,
                                             is_async=True)
 
+            # For async + multi-step, init the queue
+            if use_async_and_multi_step:
+                assert len(ctx.output_queue) == 0
+                assert seq_group_metadata_list is not None
+                ctx.output_queue.append(
+                    (None, seq_group_metadata_list, scheduler_outputs))
+
             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
                 # cache the scheduler outputs for the next iteration if we have

@@ -311,9 +325,6 @@ class _AsyncLLMEngine(LLMEngine):
             assert seq_group_metadata_list is not None
             assert scheduler_outputs is not None
 
-            assert not (self.scheduler_config.is_multi_step and \
-                        allow_async_output_proc)
-
             if not scheduler_outputs.is_empty():
                 finished_requests_ids = self.scheduler[
                     virtual_engine].get_and_reset_finished_requests_ids()

@@ -339,8 +350,13 @@ class _AsyncLLMEngine(LLMEngine):
                 last_sampled_token_ids=last_sampled_token_ids)
 
             if allow_async_output_proc:
-                execute_model_req.async_callback = self.async_callback[
-                    virtual_engine]
+                async_callback = self.async_callback_multi_step[
+                    virtual_engine] if use_async_and_multi_step \
+                    else self.async_callback[virtual_engine]
+
+                execute_model_req.async_callback = async_callback
+                execute_model_req.use_async_and_multi_step = \
+                    use_async_and_multi_step
 
             # Execute the model.
             output = await self.model_executor.execute_model_async(

@@ -350,7 +366,7 @@ class _AsyncLLMEngine(LLMEngine):
             if self.scheduler_config.is_multi_step:
                 self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if len(ctx.output_queue) > 0:
+            if not use_async_and_multi_step and len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
                 self._process_model_outputs(virtual_engine=virtual_engine,
                                             is_async=True)

@@ -362,22 +378,25 @@ class _AsyncLLMEngine(LLMEngine):
                 seq_group.finish_step()
 
         if not self._has_remaining_steps(seq_group_metadata_list):
-            # clear the cache if we have finished all the steps
+            # Clear the cache if we have finished all the steps
             if self.scheduler_config.is_multi_step:
                 self.cached_scheduler_outputs[
                     virtual_engine] = SchedulerOutputState()
 
-            # Cache results in engine
-            ctx.output_queue.append(
-                (output, seq_group_metadata_list, scheduler_outputs))
+            if use_async_and_multi_step:
+                # For async + multi-step, clear the queue
+                ctx.output_queue.clear()
+            else:
+                ctx.output_queue.append(
+                    (output, seq_group_metadata_list, scheduler_outputs))
 
             if output and allow_async_output_proc:
                 assert len(
                     output
                 ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
                 self._advance_to_next_step(
                     output[0], seq_group_metadata_list,
                     scheduler_outputs.scheduled_seq_groups)
 
             if not allow_async_output_proc:
                 self._process_model_outputs(virtual_engine=virtual_engine,

@@ -390,7 +409,11 @@ class _AsyncLLMEngine(LLMEngine):
                 self.do_tracing(scheduler_outputs)
 
         else:
-            ctx.request_outputs = []
+            # Multi-step case
+            if use_async_and_multi_step:
+                return []
+            else:
+                ctx.request_outputs = []
 
         if not self.has_unfinished_requests():
             # Drain async postprocessor (if exists)

@@ -91,7 +91,8 @@ class SchedulerOutputState:
 
 @dataclass
 class SchedulerContext:
-    output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata],
+    output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
+                              List[SequenceGroupMetadata],
                               SchedulerOutputs]] = field(
                                   default_factory=lambda: deque())
 

@@ -432,6 +433,13 @@ class LLMEngine:
             for v_id in range(self.parallel_config.pipeline_parallel_size)
         ]
 
+        self.async_callback_multi_step = [
+            functools.partial(self._process_model_outputs,
+                              virtual_engine=v_id,
+                              is_async=False)
+            for v_id in range(self.parallel_config.pipeline_parallel_size)
+        ]
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
 

@@ -1240,28 +1248,49 @@ class LLMEngine:
 
             return
 
-    def _process_model_outputs(self, virtual_engine: int,
-                               is_async: bool) -> None:
+    def _process_model_outputs(self,
+                               virtual_engine: int,
+                               is_async: bool,
+                               sampler_output: Optional[SamplerOutput] = None,
+                               is_last_output: bool = False) -> None:
         """Apply the model output to the sequences in the scheduled seq groups.
 
         virtual_engine: The engine id to operate on
 
         is_async: Indicates whether this postprocessor runs in
             parallel with the GPU forward pass and is processing
            tokens from the previous step. If this is true, then
            no tokens need to be appended since it is already done
            externally (before the next schedule() call)
 
+        sampler_output: Used with multi-step execution to provide
+            sampler_output of each step
+        is_last_output: Used with multi-step execution to indicate
+            the last step (of each multi-step group)
+
         Returns RequestOutputs that can be returned to the client.
         """
         now = time.time()
 
+        is_multi_step = sampler_output is not None
+
         ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]
 
         if len(ctx.output_queue) == 0:
             return None
 
-        (outputs, seq_group_metadata_list,
-         scheduler_outputs) = ctx.output_queue.popleft()
+        if is_multi_step:
+            # Async + multi-step case
+            (outputs, seq_group_metadata_list,
+             scheduler_outputs) = ctx.output_queue[0]
+            assert outputs is None
+            outputs = [sampler_output]
+        else:
+            # Async standard case
+            (outputs, seq_group_metadata_list,
+             scheduler_outputs) = ctx.output_queue.popleft()
+
+            assert outputs is not None
+
         # Sanity check
         assert len(seq_group_metadata_list) == len(
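The reworked _process_model_outputs above distinguishes two queue disciplines: the standard async path pops a finished (outputs, metadata, scheduler_outputs) entry from ctx.output_queue, while the async + multi-step path peeks at the entry that was seeded with None and substitutes the per-step sampler_output. A toy sketch of that peek-vs-pop behavior, with simplified stand-in types (not the vLLM API):

from collections import deque
from typing import Any, Deque, List, Optional, Tuple

# Entries mirror SchedulerContext.output_queue:
# (outputs, seq_group_metadata_list, scheduler_outputs)
Entry = Tuple[Optional[List[Any]], List[str], str]
output_queue: Deque[Entry] = deque()

def process_outputs(queue: Deque[Entry],
                    sampler_output: Optional[Any] = None) -> List[Any]:
    """Toy version of the queue handling in _process_model_outputs."""
    is_multi_step = sampler_output is not None
    if is_multi_step:
        # Async + multi-step: the entry was seeded with None before execution,
        # so peek at it and use the per-step sampler_output instead.
        outputs, metadata, scheduler_outputs = queue[0]
        assert outputs is None
        outputs = [sampler_output]
    else:
        # Standard async: a finished entry is popped and consumed once.
        outputs, metadata, scheduler_outputs = queue.popleft()
        assert outputs is not None
    return outputs

# Standard async path: push finished outputs, then pop them.
output_queue.append((["sampler_out"], ["seq_group_0"], "sched_out"))
print(process_outputs(output_queue))            # pops the entry

# Async + multi-step path: seed with None, then reuse the entry every step.
output_queue.append((None, ["seq_group_0"], "sched_out"))
print(process_outputs(output_queue, "step_0"))  # peeks, entry stays queued
print(process_outputs(output_queue, "step_1"))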
@@ -1320,7 +1349,11 @@ class LLMEngine:
             self.output_processor.process_outputs(seq_group, output,
                                                   is_async)
 
-        # Free the finished sequence groups.
+        # For async + multi-step, free finished seqs and create outputs
+        # only on the final step.
+        if is_multi_step and not is_last_output:
+            return
+
         for scheduler in self.scheduler:
             scheduler.free_finished_seq_groups()
 

@@ -1328,7 +1361,7 @@ class LLMEngine:
         for i, _ in enumerate(seq_group_metadata_list):
             scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
 
-            if i in finished_before:
+            if not is_multi_step and i in finished_before:
                 continue  # Avoids double processing
 
             seq_group = scheduled_seq_group.seq_group

@@ -1342,7 +1375,11 @@ class LLMEngine:
                 request_output = RequestOutputFactory.create(seq_group)
                 ctx.request_outputs.append(request_output)
 
-        if is_async:
+        # For async + multi-step, do stats only on the last output.
+        # Otherwise, do stats if the execution is async
+        do_stats = is_multi_step or is_async
+
+        if do_stats:
             # Log stats.
             self.do_log_stats(scheduler_outputs, outputs, finished_before)
 

@@ -1437,7 +1474,7 @@ class LLMEngine:
             "as performance will be severely degraded otherwise.")
 
         # For llm_engine, there is no pipeline parallel support, so the engine
-        # used is always 0
+        # used is always 0.
         virtual_engine = 0
 
         # These are cached outputs from previous iterations. None if on first

@@ -1447,6 +1484,10 @@ class LLMEngine:
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc
 
+        # Detect async + multi-step
+        use_async_and_multi_step = (self.scheduler_config.is_multi_step
+                                    and allow_async_output_proc)
+
         ctx = self.scheduler_contexts[virtual_engine]
 
         # Skip the scheduler if there are any remaining steps in the seq groups.

@@ -1462,11 +1503,22 @@ class LLMEngine:
              allow_async_output_proc
              ) = self.scheduler[virtual_engine].schedule()
 
+            # Detect async + multi-step
+            use_async_and_multi_step = (self.scheduler_config.is_multi_step
+                                        and allow_async_output_proc)
+
             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
                 self._process_model_outputs(virtual_engine=virtual_engine,
                                             is_async=True)
 
+            # For async + multi-step, init the queue
+            if use_async_and_multi_step:
+                assert len(ctx.output_queue) == 0
+                assert seq_group_metadata_list is not None
+                ctx.output_queue.append(
+                    (None, seq_group_metadata_list, scheduler_outputs))
+
             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
                 # cache the scheduler outputs for the next iteration if we have

@@ -1478,9 +1530,6 @@ class LLMEngine:
             assert seq_group_metadata_list is not None
             assert scheduler_outputs is not None
 
-            assert not (self.scheduler_config.is_multi_step and \
-                        allow_async_output_proc)
-
             if not scheduler_outputs.is_empty():
                 finished_requests_ids = self.scheduler[
                     virtual_engine].get_and_reset_finished_requests_ids()

@@ -1505,8 +1554,13 @@ class LLMEngine:
                 last_sampled_token_ids=last_sampled_token_ids)
 
             if allow_async_output_proc:
-                execute_model_req.async_callback = self.async_callback[
-                    virtual_engine]
+                async_callback = self.async_callback_multi_step[
+                    virtual_engine] if use_async_and_multi_step \
+                    else self.async_callback[virtual_engine]
+
+                execute_model_req.async_callback = async_callback
+                execute_model_req.use_async_and_multi_step = \
+                    use_async_and_multi_step
 
             output = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)

@@ -1518,7 +1572,7 @@ class LLMEngine:
         else:
             # Nothing scheduled => If there is pending async postprocessor,
             # then finish it here.
-            if len(ctx.output_queue) > 0:
+            if not use_async_and_multi_step and len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
                 self._process_model_outputs(virtual_engine=virtual_engine,
                                             is_async=True)

@@ -1535,18 +1589,23 @@ class LLMEngine:
             if self.scheduler_config.is_multi_step:
                 self.cached_scheduler_outputs[0] = SchedulerOutputState()
 
-            # Add results to the output_queue
-            # (for async or non-async postprocessing)
-            ctx.output_queue.append(
-                (output, seq_group_metadata_list, scheduler_outputs))
+            if use_async_and_multi_step:
+                # For async + multi-step, clear the queue
+                ctx.output_queue.clear()
+            else:
+                # Add results to the output_queue
+                # (for async or non-async postprocessing)
+                ctx.output_queue.append(
+                    (output, seq_group_metadata_list, scheduler_outputs))
 
             if output and allow_async_output_proc:
-                assert len(output) == 1, ("Multi step decoding does not work "
-                                          "with async output processing.")
+                assert len(output) == 1, (
+                    "Multi step decoding does not work "
+                    "with async output processing.")
 
                 self._advance_to_next_step(
                     output[0], seq_group_metadata_list,
                     scheduler_outputs.scheduled_seq_groups)
 
             # Check if need to run the usual non-async path
             if not allow_async_output_proc:

@@ -1560,7 +1619,10 @@ class LLMEngine:
                 self.do_tracing(scheduler_outputs)
         else:
             # Multi-step case
-            ctx.request_outputs = []
+            if use_async_and_multi_step:
+                return []
+            else:
+                ctx.request_outputs = []
 
         if not self.has_unfinished_requests():
             # Drain async postprocessor (if exists)

@@ -1295,6 +1295,7 @@ class ExecuteModelRequest(
     last_sampled_token_ids: Optional[torch.Tensor] = None
     # Async callback
     async_callback: Optional[Callable] = None
+    use_async_and_multi_step: bool = False
 
     @property
     def is_first_multi_step(self) -> bool:

@@ -1341,4 +1342,5 @@ class ExecuteModelRequest(
             finished_requests_ids=self.finished_requests_ids,
             last_sampled_token_ids=self.last_sampled_token_ids.clone()
             if self.last_sampled_token_ids is not None else None,
-            async_callback=self.async_callback)
+            async_callback=self.async_callback,
+            use_async_and_multi_step=self.use_async_and_multi_step)

@@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
     finished_requests_ids: Optional[List[str]] = None
     virtual_engine: int = 0
     async_callback: Optional[Callable] = None
+    use_async_and_multi_step: bool = False
 
     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {

@@ -1,5 +1,7 @@
+import dataclasses
+import functools
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
 try:
     from vllm.attention.backends.flash_attn import FlashAttentionMetadata

@@ -215,6 +217,46 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
         )
         return model_input
 
+    def _async_process_outputs(self, model_input: StatefulModelInput,
+                               output_proc_callback: Callable):
+        # Proceed with pythonization and output_proc in order.
+        # Stop on the first one that fails to pythonize
+        cont = True
+        for model_output in model_input.cached_outputs:
+            if not model_output.pythonized:
+                model_output.maybe_pythonize(model_input, self._copy_stream,
+                                             self.pinned_sampled_token_ids)
+                if model_output.pythonized:
+                    output_proc_callback(
+                        sampler_output=model_output.sampler_output)
+                else:
+                    cont = False
+
+            if not cont:
+                break
+
+    def _final_process_outputs(self, model_input: StatefulModelInput,
+                               output_proc_callback: Optional[Callable]):
+        assert model_input.frozen_model_input is not None
+
+        outputs = []
+        for output_id in range(len(model_input.cached_outputs)):
+            is_last_output = output_id == len(model_input.cached_outputs) - 1
+
+            output = model_input.cached_outputs[output_id]
+            if not output.pythonized:
+                output.pythonize(model_input, self._copy_stream,
+                                 self.pinned_sampled_token_ids)
+
+            if model_input.frozen_model_input.use_async_and_multi_step:
+                assert output_proc_callback is not None
+                output_proc_callback(sampler_output=output.sampler_output,
+                                     is_last_output=is_last_output)
+
+            outputs.append(output.sampler_output)
+
+        return outputs
+
     @torch.inference_mode()
     def execute_model(
         self,
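The new _async_process_outputs helper above pythonizes cached step outputs strictly in order and forwards each newly pythonized one to the callback, stopping at the first output that is not ready yet. A simplified, runnable sketch of that in-order, stop-at-first-not-ready pattern (readiness is simulated with a flag instead of the runner's real pythonization check; all names here are illustrative):

from dataclasses import dataclass
from typing import Callable, List

@dataclass
class CachedStepOutput:
    # Stand-in for the runner's cached model output of one decode step.
    sampler_output: str
    ready: bool          # whether the GPU-side result can be pythonized yet
    pythonized: bool = False

    def maybe_pythonize(self) -> None:
        # The real code checks whether the step's results have landed on CPU;
        # here readiness is just a flag.
        if self.ready:
            self.pythonized = True

def async_process_outputs(cached: List[CachedStepOutput],
                          output_proc_callback: Callable[[str], None]) -> None:
    # Forward outputs to the callback strictly in step order and stop at the
    # first one that cannot be pythonized yet.
    for step_output in cached:
        if not step_output.pythonized:
            step_output.maybe_pythonize()
            if step_output.pythonized:
                output_proc_callback(step_output.sampler_output)
            else:
                break

cached_outputs = [
    CachedStepOutput("step_0", ready=True),
    CachedStepOutput("step_1", ready=False),   # not ready: processing stops here
    CachedStepOutput("step_2", ready=True),    # not forwarded yet, order preserved
]
async_process_outputs(cached_outputs, lambda out: print("processed", out))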
@@ -271,6 +313,20 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
                 model_input = self._advance_step(
                     model_input, model_input.cached_outputs[-1].sampler_output)
 
+            output_proc_callback = None
+            if frozen_model_input.use_async_and_multi_step:
+                output_proc_callback = frozen_model_input.async_callback
+                assert output_proc_callback is not None
+                async_callback = functools.partial(
+                    self._async_process_outputs,
+                    model_input=model_input,
+                    output_proc_callback=output_proc_callback)
+
+                frozen_model_input = dataclasses.replace(  # type: ignore
+                    model_input.frozen_model_input,
+                    async_callback=async_callback)
+                assert frozen_model_input is not None
+
         # Execute the model
         output = self._base_model_runner.execute_model(frozen_model_input,
                                                        kv_caches,

@@ -301,9 +357,11 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
                 output[0].logprobs = None
             # Pythonize the output if CPU is ahead and the previous step is
             # ready.
-            for model_output in model_input.cached_outputs:
-                model_output.maybe_pythonize(model_input, self._copy_stream,
-                                             self.pinned_sampled_token_ids)
+            if not frozen_model_input.use_async_and_multi_step:
+                for model_output in model_input.cached_outputs:
+                    model_output.maybe_pythonize(model_input,
+                                                 self._copy_stream,
+                                                 self.pinned_sampled_token_ids)
 
         model_input.current_step += 1
 

@@ -316,11 +374,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
 
         # Pythonize the output and block if needed since it is the last step
         if model_input.is_last_step:
-            outputs = []
-            for output in model_input.cached_outputs:
-                output.pythonize(model_input, self._copy_stream,
-                                 self.pinned_sampled_token_ids)
-                outputs.append(output.sampler_output)
+            outputs = self._final_process_outputs(model_input,
+                                                  output_proc_callback)
             return outputs
 
         # should be [SamplerOutput]

@@ -1,3 +1,4 @@
+import dataclasses
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 

@@ -61,6 +62,13 @@ class MultiStepWorker(Worker):
                     execute_model_req.seq_group_metadata_list,
                     execute_model_req.virtual_engine,
                     execute_model_req.finished_requests_ids))
+
+            if execute_model_req.async_callback:
+                model_input.frozen_model_input = dataclasses.replace(  # type: ignore
+                    model_input.frozen_model_input,
+                    async_callback=execute_model_req.async_callback,
+                    use_async_and_multi_step=execute_model_req.
+                    use_async_and_multi_step)
         else:
             # on subsequent steps we reuse the worker input and model input
             multi_step_state = self.multi_step_states[virtual_engine]