[Core] Combine async postprocessor and multi-step (#7921)

Alexander Matveev 2024-08-29 14:18:26 -04:00 committed by GitHub
parent f205c09854
commit 3f60f2244e
8 changed files with 215 additions and 65 deletions

View File

@@ -47,10 +47,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_multi_step(example_prompts, model: str, tp_size: int,
pp_size: int, eager_mode: int,
num_scheduler_steps: int, num_prompts: int):
num_scheduler_steps: int, num_prompts: int,
is_async: bool):
prompts = example_prompts
if len(prompts) < num_prompts:
@@ -62,9 +64,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
ms_server_args = DEFAULT_SERVER_ARGS + \
["--num-scheduler-steps", f"{num_scheduler_steps}"]
# Disable output proc callback as it's not supported
# with multi-step right now
ms_server_args += ["--disable-async-output-proc"]
if not is_async:
ms_server_args += ["--disable-async-output-proc"]
if eager_mode:
ms_server_args.append("--enforce-eager")

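For reference, a minimal sketch of how the test's server arguments come together in the two modes; build_ms_server_args is a hypothetical helper (not part of the test file), while the flag names are taken from the diff above:

from typing import List

def build_ms_server_args(base_args: List[str], num_scheduler_steps: int,
                         is_async: bool, eager_mode: bool) -> List[str]:
    # Multi-step is always enabled here via --num-scheduler-steps.
    args = list(base_args) + ["--num-scheduler-steps", f"{num_scheduler_steps}"]
    if not is_async:
        # Only the synchronous variant opts out of async output processing.
        args += ["--disable-async-output-proc"]
    if eager_mode:
        args.append("--enforce-eager")
    return args

# e.g. the is_async=True case leaves async output processing enabled:
assert "--disable-async-output-proc" not in build_ms_server_args(
    [], num_scheduler_steps=8, is_async=True, eager_mode=False)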
View File

@@ -1107,10 +1107,7 @@ class Scheduler:
if not self.cache_config.enable_prefix_caching:
common_computed_block_nums = []
# TODO: Combine multi-step and async postprocessor
allow_async_output_proc: bool = (
self.use_async_output_proc
and not self.scheduler_config.is_multi_step)
allow_async_output_proc: bool = self.use_async_output_proc
# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []

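The scheduler previously forced allow_async_output_proc to False whenever multi-step was enabled; it now reports use_async_output_proc as-is, and each engine combines it with is_multi_step itself. A small illustrative sketch of that selection (select_output_proc_mode and the mode labels are hypothetical, not vLLM API):

def select_output_proc_mode(use_async_output_proc: bool,
                            is_multi_step: bool) -> str:
    # The scheduler no longer vetoes async output processing for multi-step.
    allow_async_output_proc = use_async_output_proc
    if allow_async_output_proc and is_multi_step:
        return "async_multi_step"    # per-step callback; queue holds one shared entry
    if allow_async_output_proc:
        return "async_single_step"   # standard async postprocessor
    return "sync"                    # process outputs synchronously after the step

assert select_output_proc_mode(True, True) == "async_multi_step"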
View File

@@ -279,6 +279,10 @@ class _AsyncLLMEngine(LLMEngine):
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc
# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)
ctx = self.scheduler_contexts[virtual_engine]
# skip the scheduler if there are any remaining steps in the seq groups.
@@ -289,17 +293,27 @@ class _AsyncLLMEngine(LLMEngine):
# Clear outputs on scheduler iteration start
ctx.request_outputs.clear()
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
# If current scheduler iteration has no async postprocessor,
# then we need first to drain the pending async postprocessor
# before moving forward
# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)
# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
# For async + multi-step, init the queue
if use_async_and_multi_step:
assert len(ctx.output_queue) == 0
assert seq_group_metadata_list is not None
ctx.output_queue.append(
(None, seq_group_metadata_list, scheduler_outputs))
if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
@@ -311,9 +325,6 @@ class _AsyncLLMEngine(LLMEngine):
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
assert not (self.scheduler_config.is_multi_step and \
allow_async_output_proc)
if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
@@ -339,8 +350,13 @@ class _AsyncLLMEngine(LLMEngine):
last_sampled_token_ids=last_sampled_token_ids)
if allow_async_output_proc:
execute_model_req.async_callback = self.async_callback[
virtual_engine]
async_callback = self.async_callback_multi_step[
virtual_engine] if use_async_and_multi_step \
else self.async_callback[virtual_engine]
execute_model_req.async_callback = async_callback
execute_model_req.use_async_and_multi_step = \
use_async_and_multi_step
# Execute the model.
output = await self.model_executor.execute_model_async(
@@ -350,7 +366,7 @@ class _AsyncLLMEngine(LLMEngine):
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
if len(ctx.output_queue) > 0:
if not use_async_and_multi_step and len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
@@ -362,22 +378,25 @@ class _AsyncLLMEngine(LLMEngine):
seq_group.finish_step()
if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps
# Clear the cache if we have finished all the steps
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()
# Cache results in engine
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
if use_async_and_multi_step:
# For async + multi-step, clear the queue
ctx.output_queue.clear()
else:
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if not allow_async_output_proc:
self._process_model_outputs(virtual_engine=virtual_engine,
@@ -390,7 +409,11 @@ class _AsyncLLMEngine(LLMEngine):
self.do_tracing(scheduler_outputs)
else:
ctx.request_outputs = []
# Multi-step case
if use_async_and_multi_step:
return []
else:
ctx.request_outputs = []
if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)

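The key change in the async engine is that, when async output processing and multi-step are both active, the output queue is seeded with a placeholder entry before execution, the per-step callback fills it in, and step_async returns [] because request outputs are produced by the callback. A self-contained sketch of the queue seeding; OutputQueueEntry, seed_output_queue, and the example values are illustrative stand-ins, not vLLM's objects:

from collections import deque
from typing import Any, Deque, List, Optional, Tuple

# One queue entry per scheduled batch: (sampler outputs, seq group metadata,
# scheduler outputs). In the async + multi-step case the first element starts
# out as None and is supplied step by step through the callback.
OutputQueueEntry = Tuple[Optional[List[Any]], List[Any], Any]

def seed_output_queue(output_queue: Deque[OutputQueueEntry],
                      seq_group_metadata_list: List[Any],
                      scheduler_outputs: Any,
                      use_async_and_multi_step: bool) -> None:
    if use_async_and_multi_step:
        # The queue must be empty; the single placeholder entry is shared
        # state that the per-step callback peeks at without popping.
        assert len(output_queue) == 0
        output_queue.append((None, seq_group_metadata_list, scheduler_outputs))

# Usage: seed before execute_model_async(), then return [] from the step,
# since request outputs arrive through the callback as steps complete.
q: Deque[OutputQueueEntry] = deque()
seed_output_queue(q, ["seq_group_meta"], "scheduler_outputs",
                  use_async_and_multi_step=True)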
View File

@@ -91,7 +91,8 @@ class SchedulerOutputState:
@dataclass
class SchedulerContext:
output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata],
output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
List[SequenceGroupMetadata],
SchedulerOutputs]] = field(
default_factory=lambda: deque())
@@ -432,6 +433,13 @@ class LLMEngine:
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
self.async_callback_multi_step = [
functools.partial(self._process_model_outputs,
virtual_engine=v_id,
is_async=False)
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
@@ -1240,28 +1248,49 @@ class LLMEngine:
return
def _process_model_outputs(self, virtual_engine: int,
is_async: bool) -> None:
def _process_model_outputs(self,
virtual_engine: int,
is_async: bool,
sampler_output: Optional[SamplerOutput] = None,
is_last_output: bool = False) -> None:
"""Apply the model output to the sequences in the scheduled seq groups.
virtual_engine: The engine id to operate on
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
sampler_output: Used with multi-step execution to provide
sampler_output of each step
is_last_output: Used with multi-step execution to indicate
the last step (of each multi-step group)
Returns RequestOutputs that can be returned to the client.
"""
now = time.time()
is_multi_step = sampler_output is not None
ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]
if len(ctx.output_queue) == 0:
return None
(outputs, seq_group_metadata_list,
scheduler_outputs) = ctx.output_queue.popleft()
if is_multi_step:
# Async + multi-step case
(outputs, seq_group_metadata_list,
scheduler_outputs) = ctx.output_queue[0]
assert outputs is None
outputs = [sampler_output]
else:
# Async standard case
(outputs, seq_group_metadata_list,
scheduler_outputs) = ctx.output_queue.popleft()
assert outputs is not None
# Sanity check
assert len(seq_group_metadata_list) == len(
@@ -1320,7 +1349,11 @@ class LLMEngine:
self.output_processor.process_outputs(seq_group, output,
is_async)
# Free the finished sequence groups.
# For async + multi-step, free finished seqs and create outputs
# only on the final step.
if is_multi_step and not is_last_output:
return
for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()
@@ -1328,7 +1361,7 @@ class LLMEngine:
for i, _ in enumerate(seq_group_metadata_list):
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
if i in finished_before:
if not is_multi_step and i in finished_before:
continue # Avoids double processing
seq_group = scheduled_seq_group.seq_group
@@ -1342,7 +1375,11 @@ class LLMEngine:
request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output)
if is_async:
# For async + multi-step, do stats only on the last output.
# Otherwise, do stats if the execution is async
do_stats = is_multi_step or is_async
if do_stats:
# Log stats.
self.do_log_stats(scheduler_outputs, outputs, finished_before)
@@ -1437,7 +1474,7 @@ class LLMEngine:
"as performance will be severely degraded otherwise.")
# For llm_engine, there is no pipeline parallel support, so the engine
# used is always 0
# used is always 0.
virtual_engine = 0
# These are cached outputs from previous iterations. None if on first
@@ -1447,6 +1484,10 @@ class LLMEngine:
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc
# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)
ctx = self.scheduler_contexts[virtual_engine]
# Skip the scheduler if there are any remaining steps in the seq groups.
@@ -1462,11 +1503,22 @@ class LLMEngine:
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)
# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
# For async + multi-step, init the queue
if use_async_and_multi_step:
assert len(ctx.output_queue) == 0
assert seq_group_metadata_list is not None
ctx.output_queue.append(
(None, seq_group_metadata_list, scheduler_outputs))
if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
@@ -1478,9 +1530,6 @@ class LLMEngine:
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
assert not (self.scheduler_config.is_multi_step and \
allow_async_output_proc)
if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
@@ -1505,8 +1554,13 @@ class LLMEngine:
last_sampled_token_ids=last_sampled_token_ids)
if allow_async_output_proc:
execute_model_req.async_callback = self.async_callback[
virtual_engine]
async_callback = self.async_callback_multi_step[
virtual_engine] if use_async_and_multi_step \
else self.async_callback[virtual_engine]
execute_model_req.async_callback = async_callback
execute_model_req.use_async_and_multi_step = \
use_async_and_multi_step
output = self.model_executor.execute_model(
execute_model_req=execute_model_req)
@@ -1518,7 +1572,7 @@ class LLMEngine:
else:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if len(ctx.output_queue) > 0:
if not use_async_and_multi_step and len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
@@ -1535,18 +1589,23 @@ class LLMEngine:
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[0] = SchedulerOutputState()
# Add results to the output_queue
# (for async or non-async postprocessing)
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
if use_async_and_multi_step:
# For async + multi-step, clear the queue
ctx.output_queue.clear()
else:
# Add results to the output_queue
# (for async or non-async postprocessing)
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
if output and allow_async_output_proc:
assert len(output) == 1, ("Multi step decoding does not work "
"with async output processing.")
if output and allow_async_output_proc:
assert len(output) == 1, (
"Multi step decoding does not work "
"with async output processing.")
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# Check if need to run the usual non-async path
if not allow_async_output_proc:
@@ -1560,7 +1619,10 @@ class LLMEngine:
self.do_tracing(scheduler_outputs)
else:
# Multi-step case
ctx.request_outputs = []
if use_async_and_multi_step:
return []
else:
ctx.request_outputs = []
if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)

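The revised _process_model_outputs distinguishes the two async cases: in the multi-step case it peeks at the head of the queue and substitutes the per-step sampler output, and it defers freeing finished sequence groups and building request outputs until the last step. A rough stand-alone sketch of that branching, with illustrative stand-in types and return strings rather than the engine's real objects:

from collections import deque
from typing import Any, Deque, List, Optional, Tuple

def process_outputs_sketch(
        output_queue: Deque[Tuple[Optional[List[Any]], List[Any], Any]],
        sampler_output: Optional[Any] = None,
        is_last_output: bool = False) -> Optional[str]:
    # Nothing pending: mirrors the early return in _process_model_outputs().
    if len(output_queue) == 0:
        return None
    is_multi_step = sampler_output is not None
    if is_multi_step:
        # Async + multi-step: peek at the shared entry instead of popping it,
        # and substitute the sampler output of the current step.
        outputs, seq_group_metadata_list, scheduler_outputs = output_queue[0]
        assert outputs is None
        outputs = [sampler_output]
    else:
        # Standard async case: consume one queued result.
        outputs, seq_group_metadata_list, scheduler_outputs = output_queue.popleft()
        assert outputs is not None
    # (Per-sequence-group token appending / detokenization would happen here.)
    if is_multi_step and not is_last_output:
        return "intermediate step processed"
    # Last step only: free finished seq groups and build RequestOutputs.
    return "final step processed"

# Usage against a queue seeded with a placeholder entry:
q: Deque[Tuple[Optional[List[Any]], List[Any], Any]] = deque()
q.append((None, ["seq_group_meta"], "scheduler_outputs"))
assert process_outputs_sketch(q, sampler_output="step0") == "intermediate step processed"
assert process_outputs_sketch(q, sampler_output="step7",
                              is_last_output=True) == "final step processed"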
View File

@@ -1295,6 +1295,7 @@ class ExecuteModelRequest(
last_sampled_token_ids: Optional[torch.Tensor] = None
# Async callback
async_callback: Optional[Callable] = None
use_async_and_multi_step: bool = False
@property
def is_first_multi_step(self) -> bool:
@@ -1341,4 +1342,5 @@ class ExecuteModelRequest(
finished_requests_ids=self.finished_requests_ids,
last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None,
async_callback=self.async_callback)
async_callback=self.async_callback,
use_async_and_multi_step=self.use_async_and_multi_step)

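The flag rides along with the request, including through clone(). A minimal sketch of that propagation, assuming a simplified stand-in dataclass rather than the real ExecuteModelRequest:

from dataclasses import dataclass, replace
from typing import Callable, Optional

@dataclass
class ExecuteModelRequestSketch:
    # Illustrative stand-in; the real ExecuteModelRequest has many more fields.
    async_callback: Optional[Callable] = None
    use_async_and_multi_step: bool = False

req = ExecuteModelRequestSketch(async_callback=print,
                                use_async_and_multi_step=True)
cloned = replace(req)  # clone() in the diff forwards both fields explicitly
assert cloned.use_async_and_multi_step is True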
View File

@@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0
async_callback: Optional[Callable] = None
use_async_and_multi_step: bool = False
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {

View File

@@ -1,5 +1,7 @@
import dataclasses
import functools
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
@@ -215,6 +217,46 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
)
return model_input
def _async_process_outputs(self, model_input: StatefulModelInput,
output_proc_callback: Callable):
# Proceed with pythonization and output_proc in order.
# Stop on the first one that fails to pythonize
cont = True
for model_output in model_input.cached_outputs:
if not model_output.pythonized:
model_output.maybe_pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
if model_output.pythonized:
output_proc_callback(
sampler_output=model_output.sampler_output)
else:
cont = False
if not cont:
break
def _final_process_outputs(self, model_input: StatefulModelInput,
output_proc_callback: Optional[Callable]):
assert model_input.frozen_model_input is not None
outputs = []
for output_id in range(len(model_input.cached_outputs)):
is_last_output = output_id == len(model_input.cached_outputs) - 1
output = model_input.cached_outputs[output_id]
if not output.pythonized:
output.pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
if model_input.frozen_model_input.use_async_and_multi_step:
assert output_proc_callback is not None
output_proc_callback(sampler_output=output.sampler_output,
is_last_output=is_last_output)
outputs.append(output.sampler_output)
return outputs
@torch.inference_mode()
def execute_model(
self,
@@ -271,6 +313,20 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input = self._advance_step(
model_input, model_input.cached_outputs[-1].sampler_output)
output_proc_callback = None
if frozen_model_input.use_async_and_multi_step:
output_proc_callback = frozen_model_input.async_callback
assert output_proc_callback is not None
async_callback = functools.partial(
self._async_process_outputs,
model_input=model_input,
output_proc_callback=output_proc_callback)
frozen_model_input = dataclasses.replace( # type: ignore
model_input.frozen_model_input,
async_callback=async_callback)
assert frozen_model_input is not None
# Execute the model
output = self._base_model_runner.execute_model(frozen_model_input,
kv_caches,
@@ -301,9 +357,11 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
output[0].logprobs = None
# Pythonize the output if CPU is ahead and the previous step is
# ready.
for model_output in model_input.cached_outputs:
model_output.maybe_pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
if not frozen_model_input.use_async_and_multi_step:
for model_output in model_input.cached_outputs:
model_output.maybe_pythonize(model_input,
self._copy_stream,
self.pinned_sampled_token_ids)
model_input.current_step += 1
@@ -316,11 +374,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# Pythonize the output and block if needed since it is the last step
if model_input.is_last_step:
outputs = []
for output in model_input.cached_outputs:
output.pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
outputs.append(output.sampler_output)
outputs = self._final_process_outputs(model_input,
output_proc_callback)
return outputs
# should be [SamplerOutput]

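The model runner's new helpers process cached step outputs strictly in order: each output that can already be pythonized is handed to the callback immediately, and the loop stops at the first one still pending; _final_process_outputs then pythonizes any remainder and flags the last step. A self-contained sketch of the in-order draining, with stand-in classes (CachedStepOutput and drain_in_order are illustrative, not vLLM's types):

from typing import Callable, List

class CachedStepOutput:
    # Stand-in for the model runner's cached per-step output.
    def __init__(self, ready: bool):
        self.ready = ready            # stand-in for "GPU sampling finished"
        self.pythonized = False
        self.sampler_output = object()

    def maybe_pythonize(self) -> None:
        # Pythonize only if the step's results are already available.
        if self.ready:
            self.pythonized = True

def drain_in_order(cached_outputs: List[CachedStepOutput],
                   output_proc_callback: Callable) -> None:
    for step_output in cached_outputs:
        if not step_output.pythonized:
            step_output.maybe_pythonize()
            if step_output.pythonized:
                # Hand the completed step to the engine's postprocessor.
                output_proc_callback(sampler_output=step_output.sampler_output)
            else:
                # Keep step order: do not process later steps before an
                # earlier one that is still in flight.
                break

# Usage: only the first (ready) step reaches the callback; the second blocks
# the rest until a later drain attempt.
drain_in_order([CachedStepOutput(True), CachedStepOutput(False)],
               lambda sampler_output: None)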
View File

@@ -1,3 +1,4 @@
import dataclasses
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
@@ -61,6 +62,13 @@ class MultiStepWorker(Worker):
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
if execute_model_req.async_callback:
model_input.frozen_model_input = dataclasses.replace( # type: ignore
model_input.frozen_model_input,
async_callback=execute_model_req.async_callback,
use_async_and_multi_step=execute_model_req.
use_async_and_multi_step)
else:
# on subsequent steps we reuse the worker input and model input
multi_step_state = self.multi_step_states[virtual_engine]
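The use of dataclasses.replace here suggests the frozen model input cannot simply be mutated in place; the worker rebuilds it with the callback and flag attached. A minimal sketch of that pattern under an assumed frozen stand-in dataclass (FrozenModelInputSketch is illustrative, not vLLM's ModelInputForGPU):

import dataclasses
from typing import Callable, Optional

@dataclasses.dataclass(frozen=True)
class FrozenModelInputSketch:
    # Illustrative stand-in for the frozen model input.
    async_callback: Optional[Callable] = None
    use_async_and_multi_step: bool = False

frozen = FrozenModelInputSketch()
# replace() builds a new instance with the callback and flag attached,
# which is why assignment is avoided in the worker and model runner.
frozen = dataclasses.replace(frozen,
                             async_callback=print,
                             use_async_and_multi_step=True)
assert frozen.use_async_and_multi_step and frozen.async_callback is not None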