[Bugfix] Fix incorrect updates to num_computed_tokens in multi-step scheduling (#9038)

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
commit cb3b2b9ba4 (parent fdf59d30ea)
Author: Varun Sundar Rabindranath
Date:   2024-10-06 15:48:11 -04:00 (committed via GitHub)
6 changed files with 179 additions and 110 deletions
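The invariant the new test (first diff below) checks: num_computed_tokens counts every token that has been run through the model, so all prompt tokens and all sampled tokens are counted except the most recently sampled one, which has been appended to the sequence but not yet computed. A rough standalone sketch of that arithmetic (hypothetical helper name, not vLLM code):

def expected_num_computed_tokens(prompt_len: int, num_sampled_tokens: int) -> int:
    # Every prompt token and every sampled token has been run through the
    # model except the most recently sampled one, which is appended to the
    # sequence but not yet computed.
    return prompt_len + num_sampled_tokens - 1

# Single prompt step over a 10-token prompt: one token sampled, 10 computed.
assert expected_num_computed_tokens(10, 1) == 10
# Request finished after 17 output tokens: 26 computed.
assert expected_num_computed_tokens(10, 17) == 26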

@@ -0,0 +1,81 @@
import pytest

from tests.conftest import VllmRunner
from tests.core.utils import create_dummy_prompt
from vllm.engine.llm_engine import LLMEngine
from vllm.platforms import current_platform
from vllm.sequence import SequenceGroup

MODEL = "JackFram/llama-160m"


def add_seq_group_to_engine(engine: LLMEngine, seq_group: SequenceGroup):
    scheduler = engine.scheduler[0]
    scheduler.add_seq_group(seq_group)


@pytest.mark.parametrize("num_scheduler_steps", [1, 8])
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
@pytest.mark.parametrize("enforce_eager", [False, True])
def test_num_computed_tokens_update(num_scheduler_steps: int,
                                    enable_chunked_prefill: bool,
                                    enforce_eager: bool):

    is_multi_step = num_scheduler_steps > 1
    is_multi_step_chunked_prefill = is_multi_step and enable_chunked_prefill

    if is_multi_step_chunked_prefill and current_platform.is_rocm():
        pytest.skip("Multi-step with Chunked-Prefill does not support "
                    "rocm_flash_attn backend")

    # Make a vllm engine
    runner = VllmRunner(model_name=MODEL,
                        gpu_memory_utilization=0.7,
                        use_v2_block_manager=True,
                        num_scheduler_steps=num_scheduler_steps,
                        enable_chunked_prefill=enable_chunked_prefill,
                        enforce_eager=enforce_eager)
    engine: LLMEngine = runner.model.llm_engine

    # In multi-step + chunked-prefill there is no separate single prompt step.
    # What is scheduled will run for num_scheduler_steps always.
    num_prompt_steps = num_scheduler_steps \
        if is_multi_step_chunked_prefill else 1

    num_output_tokens_list = [4, 8, 12, 15, 16, 17]

    # Create sequence and add to engine
    prompt_len = 10

    for req_idx, num_output_tokens in enumerate(num_output_tokens_list):
        seq, seq_group = create_dummy_prompt(request_id=str(req_idx),
                                             prompt_length=prompt_len,
                                             min_tokens=num_output_tokens,
                                             max_tokens=num_output_tokens)
        add_seq_group_to_engine(engine, seq_group)

        assert seq.data.get_num_computed_tokens() == 0

        for _ in range(num_prompt_steps):
            # prompt steps
            engine.step()

        if not seq.is_finished():
            prompt_num_computed_tokens = seq.data.get_num_computed_tokens()
            # Test correctness of num_computed_tokens after the prompt steps
            assert prompt_num_computed_tokens == \
                prompt_len + num_prompt_steps - 1

            decode_step_counter = 0
            while not seq.is_finished():
                # Test correctness of num_computed_tokens after the decode steps
                assert seq.data.get_num_computed_tokens(
                ) == prompt_num_computed_tokens + decode_step_counter
                for _ in range(num_scheduler_steps):
                    # decode step
                    engine.step()
                    decode_step_counter += 1

        # Test correctness of num_computed_tokens after the sequence finish.
        assert seq.data.get_num_computed_tokens(
        ) == prompt_len + num_output_tokens - 1
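For example, with num_scheduler_steps=8 and chunked prefill enabled, num_prompt_steps is 8, so after the prompt burst the test expects 10 + 8 - 1 == 17 computed tokens, and a request that finishes after 17 output tokens ends at 10 + 17 - 1 == 26. The path of this new test file is not shown in this view; it can presumably be selected by name with "pytest -k test_num_computed_tokens_update" once vLLM is installed from source.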

@@ -16,6 +16,8 @@ def create_dummy_prompt(
     use_beam_search: bool = False,
     best_of: int = 1,
     prompt_tokens: Optional[List[int]] = None,
+    min_tokens: int = 0,
+    max_tokens: int = 16,
 ) -> Tuple[Sequence, SequenceGroup]:
     if not block_size:
         block_size = prompt_length
@@ -36,7 +38,9 @@ def create_dummy_prompt(
                               arrival_time=time.time(),
                               sampling_params=SamplingParams(
                                   use_beam_search=use_beam_search,
-                                  best_of=best_of),
+                                  best_of=best_of,
+                                  max_tokens=max_tokens,
+                                  min_tokens=min_tokens),
                               lora_request=lora_request)

     return prompt, seq_group

@@ -191,12 +191,22 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
         )
         return self._cached_decode_metadata

-    def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
-                     sampled_token_ids: Optional[torch.Tensor],
-                     block_size: int, num_seqs: int, num_queries: int):
+    def advance_step(self,
+                     model_input: "ModelInputForGPUWithSamplingMetadata",
+                     sampled_token_ids: Optional[torch.Tensor],
+                     block_size: int,
+                     num_seqs: int,
+                     num_queries: int,
+                     turn_prefills_into_decodes: bool = False):
         """
         Update metadata in-place to advance one decode step.
         """
+        assert not turn_prefills_into_decodes, \
+            ("Chunked prefill is not supported with rocm_flash_attn yet. "
+             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
+             "specific parameter.")
+
         # When using cudagraph, the num_seqs is padded to the next captured
         # batch sized, but num_queries tracks the actual number of requests in
         # the batch. For --enforce-eager mode, num_seqs == num_queries

@@ -962,6 +962,45 @@ class LLMEngine:
             return

+    def _update_num_computed_tokens_for_multi_step_prefill(
+            self, seq_group: SequenceGroup,
+            seq_group_meta: SequenceGroupMetadata,
+            is_first_step_output: Optional[bool]):
+        """
+        This function updates num_computed_tokens for prompt sequences
+        when Multi-Step is enabled.
+
+        seq_group: SequenceGroup to update the num_computed_tokens for.
+        seq_group_meta: Metadata of the given SequenceGroup.
+        is_first_step_output: Optional[bool] -
+            When available, is_first_step_output indicates if the appended
+            output token is the output of the first step in multi-step.
+            A value of None indicates that outputs from all steps in
+            multi-step are submitted in a single burst.
+        """
+
+        assert self.scheduler_config.is_multi_step
+
+        if not seq_group_meta.is_prompt:
+            # num_computed_token updates for multi-step decodes happen after
+            # the tokens are appended to the sequence.
+            return
+
+        do_update: bool = False
+        if self.scheduler_config.chunked_prefill_enabled:
+            # In multi-step + chunked-prefill case, the prompt sequences
+            # that are scheduled are fully processed in the first step.
+            do_update = is_first_step_output is None or is_first_step_output
+        else:
+            # Normal multi-step decoding case. In this case prompt-sequences
+            # are actually single-stepped. Always update in this case.
+            assert seq_group.state.num_steps == 1
+            do_update = True
+
+        if do_update:
+            seq_group.update_num_computed_tokens(
+                seq_group_meta.token_chunk_size)
+
     def _process_model_outputs(self,
                                ctx: SchedulerContext,
                                request_id: Optional[str] = None) -> None:
@@ -972,64 +1011,6 @@ class LLMEngine:
             request_id: If provided, then only this request is going to be processed
         """
-
-        def update_prefill_num_computed_tokens(
-                seq_group: SequenceGroup,
-                seq_group_meta: SequenceGroupMetadata, num_outputs: int,
-                is_first_step_output: Optional[bool]) -> None:
-            """
-            When multi-step and chunked-prefill are enabled together, the
-            prefill sequence scheduled for multi-step execution turn into
-            decodes in the first step itself. This function accounts
-            for that conversion.
-
-            seq_group: SequenceGroup - A prefill seq_group
-            seq_group_meta: SequenceGroupMetadata - Metadata of the given
-                prefill seq_group
-            num_outputs: int - number of output tokens being processed for the
-                given seq_group
-            is_first_step_output: Optional[bool] -
-                If multi-step is enabled and num_outputs is 1, this value
-                indicates if this outputs belongs to the first step in the
-                multi-step.
-                If multi-step is enabled and num_outputs > 1, this value
-                must be None, as num_outputs > 1 indicates that outputs from
-                all the steps in multi-step are submitted in a single burst.
-                When multi-step is disabled, this value is always True.
-            """
-
-            assert seq_group_meta.is_prompt
-
-            token_chunk_size = seq_group_meta.token_chunk_size
-
-            if num_outputs == 1:
-                assert is_first_step_output is not None
-
-                if seq_group_meta.state.num_steps == 1:
-                    assert is_first_step_output is True
-                    seq_group.update_num_computed_tokens(token_chunk_size)
-                    return
-
-                # multi-step prefill is only supported when multi-step is
-                # enabled with chunked prefill
-                assert self.scheduler_config.is_multi_step and \
-                    self.scheduler_config.chunked_prefill_enabled
-                if is_first_step_output is True:
-                    # This sequence is a prompt during the first step only.
-                    seq_group.update_num_computed_tokens(token_chunk_size)
-                return
-
-            assert is_first_step_output is None
-
-            # multi-step prefill is only supported when multi-step is
-            # enabled with chunked prefill. Outputs from all the steps are
-            # submitted in a single burst.
-            assert self.scheduler_config.is_multi_step and \
-                self.scheduler_config.chunked_prefill_enabled
-
-            assert num_outputs == seq_group_meta.state.num_steps, \
-                f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}" #noqa
-
-            # This sequence is a prompt during the first step only.
-            seq_group.update_num_computed_tokens(token_chunk_size)

         now = time.time()

         if len(ctx.output_queue) == 0:
@@ -1090,7 +1071,7 @@
             seq_group_meta = seq_group_metadata_list[i]
             scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

-            seq_group = scheduled_seq_group.seq_group
+            seq_group: SequenceGroup = scheduled_seq_group.seq_group

             if seq_group.is_finished():
                 finished_before.append(i)
@@ -1101,14 +1082,14 @@
             else:
                 output = [outputs_by_sequence_group[0][i]]

-            if not is_async and seq_group_meta.is_prompt:
-                # Updates for all decodes happen when we actually append the
-                # token ids to the seq in process_outputs.
-                update_prefill_num_computed_tokens(seq_group, seq_group_meta,
-                                                   len(output),
-                                                   is_first_step_output)
-            elif not is_async:
-                seq_group.update_num_computed_tokens(1)
+            if not is_async:
+                if self.scheduler_config.is_multi_step:
+                    # Updates happen only if the sequence is prefill
+                    self._update_num_computed_tokens_for_multi_step_prefill(
+                        seq_group, seq_group_meta, is_first_step_output)
+                else:
+                    seq_group.update_num_computed_tokens(
+                        seq_group_meta.token_chunk_size)

             if outputs:
                 for o in outputs:
@@ -1132,16 +1113,8 @@
             else:
                 self.output_processor.process_prompt_logprob(seq_group, output)
                 if seq_group_meta.do_sample:
-                    output_token_num = self.output_processor.process_outputs(
-                        seq_group, output, is_async)
-                    if self.speculative_config:
-                        # We -1 here because we always
-                        # (w/o speculative decoding) add the number of
-                        # computed tokens by one in the decoding phase.
-                        # Therefore, we remove that one token that
-                        # is already added.
-                        seq_group.update_num_computed_tokens(output_token_num -
-                                                             1)
+                    self.output_processor.process_outputs(
+                        seq_group, output, is_async)

             if seq_group.is_finished():
                 finished_now.append(i)
@@ -1250,20 +1223,15 @@
             if seq_group.is_finished():
                 continue

-            if seq_group_metadata.is_prompt:
-                if self.scheduler_config.is_multi_step and \
-                    self.scheduler_config.chunked_prefill_enabled:
-                    # Prompts are scheduled in multi-step only when
-                    # chunking is enabled. These prompts turn into
-                    # decodes after the very first step. Therefore,
-                    # we skip the update to the num_computed_tokens
-                    # here.
-                    seq_group.update_num_computed_tokens(1)
-                else:
-                    seq_group.update_num_computed_tokens(
-                        seq_group_metadata.token_chunk_size)
-            else:
-                seq_group.update_num_computed_tokens(1)
+            if self.scheduler_config.is_multi_step:
+                # Updates happen only if the sequence is prefill
+                self._update_num_computed_tokens_for_multi_step_prefill(
+                    seq_group, seq_group_metadata,
+                    seq_group.state.num_steps == 1)
+            else:
+                seq_group.update_num_computed_tokens(
+                    seq_group_metadata.token_chunk_size)

             if seq_group_metadata.do_sample:
                 assert len(sequence_group_outputs.samples) == 1, (
                     "Async output processor expects a single sample"
@@ -1273,6 +1241,14 @@
                 assert len(seq_group.seqs) == 1
                 seq = seq_group.seqs[0]

-                seq.append_token_id(sample.output_token, sample.logprobs)
+                if self.scheduler_config.is_multi_step:
+                    is_prefill_append = seq.data.get_num_uncomputed_tokens(
+                    ) == 0
+                    seq.append_token_id(sample.output_token, sample.logprobs)
+                    if not is_prefill_append:
+                        seq_group.update_num_computed_tokens(1)
+                else:
+                    seq.append_token_id(sample.output_token, sample.logprobs)

     def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
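Taken together, the llm_engine.py changes route all prompt-side num_computed_tokens updates through the new helper and defer decode-side updates until the sampled tokens are actually appended. A minimal standalone sketch of the prompt-side decision (the Fake* names below are hypothetical stand-ins, not vLLM classes):

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeSchedulerConfig:
    is_multi_step: bool
    chunked_prefill_enabled: bool


@dataclass
class FakeSeqGroupMeta:
    is_prompt: bool
    token_chunk_size: int


def prefill_computed_token_delta(cfg: FakeSchedulerConfig,
                                 meta: FakeSeqGroupMeta,
                                 is_first_step_output: Optional[bool]) -> int:
    """How many tokens to mark as computed for one model output."""
    if not meta.is_prompt:
        # Decode tokens are counted when they are appended to the sequence
        # (see the output-processor diffs below).
        return 0
    if cfg.is_multi_step and cfg.chunked_prefill_enabled:
        # The prompt is fully processed in the first step of the burst;
        # later steps of the same burst add nothing here. None means the
        # whole burst arrived at once, which includes the first step.
        first = is_first_step_output is None or is_first_step_output
        return meta.token_chunk_size if first else 0
    # Plain multi-step (or single-step): prompts are single-stepped, so the
    # whole scheduled chunk is counted.
    return meta.token_chunk_size


# An 8-step burst over a 10-token prompt counts the prompt exactly once.
cfg = FakeSchedulerConfig(is_multi_step=True, chunked_prefill_enabled=True)
meta = FakeSeqGroupMeta(is_prompt=True, token_chunk_size=10)
assert prefill_computed_token_delta(cfg, meta, True) == 10
assert prefill_computed_token_delta(cfg, meta, False) == 0
assert prefill_computed_token_delta(cfg, meta, None) == 10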

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, List, Optional
+from typing import Callable, List

 from vllm.config import SchedulerConfig
 from vllm.core.scheduler import Scheduler
@@ -58,14 +58,10 @@ class SequenceGroupOutputProcessor(ABC):
     @abstractmethod
     def process_outputs(self, sequence_group: SequenceGroup,
                         outputs: List[SequenceGroupOutput],
-                        is_async: bool) -> Optional[int]:
+                        is_async: bool) -> None:
         """Process new token ids for the sequence group. Handles logic such as
         detokenization, stop checking, and freeing/forking sequences in the
         scheduler.
-
-        Return the number of new tokens generated in the sequence group.
-        The returned value is optional because it is only used for
-        speculative decoding mqa scorer.
         """
         pass

@@ -1,5 +1,5 @@
 import functools
-from typing import Callable, List, Optional
+from typing import Callable, List

 from vllm.core.scheduler import Scheduler
 from vllm.engine.output_processor.interfaces import (
@@ -69,7 +69,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
     def process_outputs(self,
                         sequence_group: SequenceGroup,
                         outputs: List[SequenceGroupOutput],
-                        is_async: bool = False) -> Optional[int]:
+                        is_async: bool = False) -> None:
         """Append new tokens in the outputs to sequences in the sequence group.

         This only supports sequence groups of size 1. It supports greater than
@@ -84,10 +84,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                 tokens from the previous step. If this is true, then
                 no tokens need to be appended since it is already done
                 externally (before the next schedule() call)
-
-        Returns:
-            The number of tokens appended to the sequence. This is optional
-            because only speculative decode uses this return value.
         """
         # Sequences can be in RUNNING or FINISHED_ABORTED state
         # once scheduled, as a sequence is moved to FINSIHED_ABORTED
@@ -110,7 +106,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
             # was already appended, so we only need to do the rest of the
             # postprocessor: Detokenization + stopping logic
             self._process_decode_and_stop(seq, sequence_group.sampling_params)
-            return None
         else:
             # Standard multi-step case
@@ -126,7 +121,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
             ]
             assert valid_samples

-            return self._process_seq_outputs(seq, valid_samples,
-                                             sequence_group.sampling_params)
+            self._process_seq_outputs(seq, valid_samples,
+                                      sequence_group.sampling_params)

     def _process_decode_and_stop(self, seq: Sequence,
@@ -145,7 +140,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):

     def _process_seq_outputs(self, seq: Sequence,
                              valid_samples: List[SequenceOutput],
-                             sampling_params: SamplingParams) -> int:
+                             sampling_params: SamplingParams) -> None:
         output_token_ids = [sample.output_token for sample in valid_samples]
         output_logprobs = [sample.logprobs for sample in valid_samples]
@@ -168,6 +163,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                 output_token_ids = output_token_ids[:i + 1]
                 break

+        is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
         # Incrementally append tokens to the sequence, as if we had only one new
         # token.
         for output_token_id, output_logprob in zip(output_token_ids,
@@ -177,8 +173,14 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                 logprobs=output_logprob,
             )

+            if is_prefill_sampled_token:
+                is_prefill_sampled_token = False
+            else:
+                # Update num_computed_tokens iff the sampled token is not from
+                # a prefill step.
+                seq.data.update_num_computed_tokens(1)
+
             self._process_decode_and_stop(seq, sampling_params)

             if seq.is_finished():
                 break
-
-        return len(output_token_ids)
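The multi_step.py change applies the same rule on the decode side: when a burst of sampled tokens is appended, the token sampled off the prefill does not bump num_computed_tokens, and every later token in the burst bumps it by one. A small hypothetical sketch with plain ints (not vllm.sequence objects):

from typing import List


def apply_sampled_burst(num_computed_tokens: int, sampled_tokens: List[int],
                        first_is_prefill_sample: bool) -> int:
    """Mirror of the append loop: skip the counter bump for the prefill's own
    sampled token, bump by one for every other appended token."""
    skip_next_bump = first_is_prefill_sample
    for _token in sampled_tokens:
        if skip_next_bump:
            skip_next_bump = False
        else:
            num_computed_tokens += 1
    return num_computed_tokens


# A 10-token prompt whose first multi-step burst yields 8 sampled tokens:
# the prompt step already counted 10, and the burst adds only 7 more, since
# the last sampled token stays uncounted until it is itself computed.
assert apply_sampled_burst(10, list(range(8)), True) == 17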