[Bugfix] Fix incorrect updates to num_computed_tokens in multi-step scheduling (#9038)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in: parent fdf59d30ea, commit cb3b2b9ba4
tests/core/test_num_computed_tokens_update.py (new file, +81 lines)
@@ -0,0 +1,81 @@
+import pytest
+
+from tests.conftest import VllmRunner
+from tests.core.utils import create_dummy_prompt
+from vllm.engine.llm_engine import LLMEngine
+from vllm.platforms import current_platform
+from vllm.sequence import SequenceGroup
+
+MODEL = "JackFram/llama-160m"
+
+
+def add_seq_group_to_engine(engine: LLMEngine, seq_group: SequenceGroup):
+    scheduler = engine.scheduler[0]
+    scheduler.add_seq_group(seq_group)
+
+
+@pytest.mark.parametrize("num_scheduler_steps", [1, 8])
+@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
+@pytest.mark.parametrize("enforce_eager", [False, True])
+def test_num_computed_tokens_update(num_scheduler_steps: int,
+                                    enable_chunked_prefill: bool,
+                                    enforce_eager: bool):
+
+    is_multi_step = num_scheduler_steps > 1
+    is_multi_step_chunked_prefill = is_multi_step and enable_chunked_prefill
+
+    if is_multi_step_chunked_prefill and current_platform.is_rocm():
+        pytest.skip("Multi-step with Chunked-Prefill does not support "
+                    "rocm_flash_attn backend")
+
+    # Make a vllm engine
+    runner = VllmRunner(model_name=MODEL,
+                        gpu_memory_utilization=0.7,
+                        use_v2_block_manager=True,
+                        num_scheduler_steps=num_scheduler_steps,
+                        enable_chunked_prefill=enable_chunked_prefill,
+                        enforce_eager=enforce_eager)
+    engine: LLMEngine = runner.model.llm_engine
+
+    # In multi-step + chunked-prefill there is no separate single prompt step.
+    # What is scheduled will run for num_scheduler_steps always.
+    num_prompt_steps = num_scheduler_steps \
+        if is_multi_step_chunked_prefill else 1
+
+    num_output_tokens_list = [4, 8, 12, 15, 16, 17]
+
+    # Create sequence and add to engine
+    prompt_len = 10
+
+    for req_idx, num_output_tokens in enumerate(num_output_tokens_list):
+        seq, seq_group = create_dummy_prompt(request_id=str(req_idx),
+                                             prompt_length=prompt_len,
+                                             min_tokens=num_output_tokens,
+                                             max_tokens=num_output_tokens)
+        add_seq_group_to_engine(engine, seq_group)
+
+        assert seq.data.get_num_computed_tokens() == 0
+
+        for _ in range(num_prompt_steps):
+            # prompt steps
+            engine.step()
+
+        if not seq.is_finished():
+            prompt_num_computed_tokens = seq.data.get_num_computed_tokens()
+            # Test correctness of num_computed_tokens after the prompt steps
+            assert prompt_num_computed_tokens == \
+                prompt_len + num_prompt_steps - 1
+
+            decode_step_counter = 0
+            while not seq.is_finished():
+                # Test correctness of num_computed_tokens after the decode steps
+                assert seq.data.get_num_computed_tokens(
+                ) == prompt_num_computed_tokens + decode_step_counter
+                for _ in range(num_scheduler_steps):
+                    # decode step
+                    engine.step()
+                    decode_step_counter += 1
+
+            # Test correctness of num_computed_tokens after the sequence finish.
+            assert seq.data.get_num_computed_tokens(
+            ) == prompt_len + num_output_tokens - 1
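Note: the assertions in the new test all follow one invariant: num_computed_tokens counts tokens whose KV-cache entries have been computed, so it trails the sequence length by the one token that was just sampled but not yet fed back. The sketch below mirrors that arithmetic with plain integers; the helper name is illustrative and not part of vLLM.

# Illustrative sketch of the invariant the test above asserts (assumed
# bookkeeping, not vLLM code).

def expected_num_computed_tokens(prompt_len: int, num_prompt_steps: int,
                                 decode_steps: int) -> int:
    # After the prompt step(s): the whole prompt is computed, plus one extra
    # token per additional step in the first multi-step burst, minus the
    # freshly sampled token that has not been fed back yet.
    after_prompt = prompt_len + num_prompt_steps - 1
    # Each later decode step computes exactly one more token.
    return after_prompt + decode_steps


# Example: prompt_len=10, multi-step + chunked-prefill with 8 scheduler steps.
assert expected_num_computed_tokens(10, 8, 0) == 17
# Single-step scheduling: one prompt step, then 4 decode steps.
assert expected_num_computed_tokens(10, 1, 4) == 14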
@@ -16,6 +16,8 @@ def create_dummy_prompt(
         use_beam_search: bool = False,
         best_of: int = 1,
         prompt_tokens: Optional[List[int]] = None,
+        min_tokens: int = 0,
+        max_tokens: int = 16,
 ) -> Tuple[Sequence, SequenceGroup]:
     if not block_size:
         block_size = prompt_length
@@ -36,7 +38,9 @@ def create_dummy_prompt(
                               arrival_time=time.time(),
                               sampling_params=SamplingParams(
                                   use_beam_search=use_beam_search,
-                                  best_of=best_of),
+                                  best_of=best_of,
+                                  max_tokens=max_tokens,
+                                  min_tokens=min_tokens),
                               lora_request=lora_request)

     return prompt, seq_group
@@ -191,12 +191,22 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
         )
         return self._cached_decode_metadata

-    def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
+    def advance_step(self,
+                     model_input: "ModelInputForGPUWithSamplingMetadata",
                      sampled_token_ids: Optional[torch.Tensor],
-                     block_size: int, num_seqs: int, num_queries: int):
+                     block_size: int,
+                     num_seqs: int,
+                     num_queries: int,
+                     turn_prefills_into_decodes: bool = False):
         """
         Update metadata in-place to advance one decode step.
         """
+
+        assert not turn_prefills_into_decodes, \
+            ("Chunked prefill is not supported with rocm_flash_attn yet."
+             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
+             "specific parameter.")
+
         # When using cudagraph, the num_seqs is padded to the next captured
         # batch sized, but num_queries tracks the actual number of requests in
         # the batch. For --enforce-eager mode, num_seqs == num_queries
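Note: the cudagraph comment above is easiest to see with concrete numbers. The sketch below is illustrative only; the capture sizes and the helper are assumptions, not the backend's actual API.

# Illustrative sketch of the num_seqs vs. num_queries distinction described
# in the comment above (assumed capture sizes; not vLLM's actual API).

CUDAGRAPH_CAPTURE_SIZES = [1, 2, 4, 8, 16]  # assumed captured batch sizes


def padded_num_seqs(num_queries: int) -> int:
    # With cudagraphs, the batch is padded up to the next captured size.
    return next(size for size in CUDAGRAPH_CAPTURE_SIZES
                if size >= num_queries)


num_queries = 5                   # actual requests in the batch
num_seqs = padded_num_seqs(num_queries)
assert num_seqs == 8              # padded to fit the captured graph
# With --enforce-eager there is no padding, so num_seqs == num_queries.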
@@ -962,6 +962,45 @@ class LLMEngine:

             return

+    def _update_num_computed_tokens_for_multi_step_prefill(
+            self, seq_group: SequenceGroup,
+            seq_group_meta: SequenceGroupMetadata,
+            is_first_step_output: Optional[bool]):
+        """
+        This function updates num_computed_tokens for prompt sequences
+        when Multi-Step is enabled.
+
+        seq_group: SequenceGroup to update the num_computed_tokens for.
+        seq_group_meta: Metadata of the given SequenceGroup.
+        is_first_step_output: Optional[bool] -
+            When available, is_first_step_output indicates if the appended
+            output token is the output of the first-step in multi-step.
+            A value of None indicates that outputs from all steps in
+            multi-step are submitted in a single burst.
+        """
+
+        assert self.scheduler_config.is_multi_step
+
+        if not seq_group_meta.is_prompt:
+            # num_computed_token updates for multi-step decodes happen after
+            # the tokens are appended to the sequence.
+            return
+
+        do_update: bool = False
+        if self.scheduler_config.chunked_prefill_enabled:
+            # In multi-step + chunked-prefill case, the prompt sequences
+            # that are scheduled are fully processed in the first step.
+            do_update = is_first_step_output is None or is_first_step_output
+        else:
+            # Normal multi-step decoding case. In this case prompt-sequences
+            # are actually single-stepped. Always update in this case.
+            assert seq_group.state.num_steps == 1
+            do_update = True
+
+        if do_update:
+            seq_group.update_num_computed_tokens(
+                seq_group_meta.token_chunk_size)
+
     def _process_model_outputs(self,
                                ctx: SchedulerContext,
                                request_id: Optional[str] = None) -> None:
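Note: for prompt sequences the new helper reduces to a small decision. A simplified, standalone sketch of that decision follows (illustrative names, not the engine's API).

# Simplified sketch of the prompt-sequence decision made by the helper above.
from typing import Optional


def should_update_prefill(chunked_prefill_enabled: bool,
                          is_first_step_output: Optional[bool]) -> bool:
    if chunked_prefill_enabled:
        # Multi-step + chunked-prefill: the scheduled prompt chunk is fully
        # processed in the first step, so count it once, either on the first
        # step's output or when all step outputs arrive as a single burst
        # (is_first_step_output is None).
        return is_first_step_output is None or is_first_step_output
    # Plain multi-step: prompts are single-stepped, so always update.
    return True


assert should_update_prefill(True, True) is True    # first step of the burst
assert should_update_prefill(True, False) is False  # later steps: already counted
assert should_update_prefill(True, None) is True    # whole burst at once
assert should_update_prefill(False, True) is True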
@@ -972,64 +1011,6 @@ class LLMEngine:
             request_id: If provided, then only this request is going to be processed
         """
-
-        def update_prefill_num_computed_tokens(
-                seq_group: SequenceGroup,
-                seq_group_meta: SequenceGroupMetadata, num_outputs: int,
-                is_first_step_output: Optional[bool]) -> None:
-            """
-            When multi-step and chunked-prefill are enabled together, the
-            prefill sequence scheduled for multi-step execution turn into
-            decodes in the first step itself. This function accounts
-            for that conversion.
-
-            seq_group: SequenceGroup - A prefill seq_group
-            seq_group_meta: SequenceGroupMetadata - Metadata of the given
-                prefill seq_group
-            num_outputs: int - number of output tokens being processed for the
-                given seq_group
-            is_first_step_output: Optional[bool] -
-                If multi-step is enabled and num_outputs is 1, this value
-                indicates if this outputs belongs to the first step in the
-                multi-step.
-                If multi-step is enabled and num_outputs > 1, this value
-                must be None, as num_outputs > 1 indicates that outputs from
-                all the steps in multi-step are submitted in a single burst.
-                When multi-step is disabled, this value is always True.
-            """
-
-            assert seq_group_meta.is_prompt
-
-            token_chunk_size = seq_group_meta.token_chunk_size
-
-            if num_outputs == 1:
-                assert is_first_step_output is not None
-
-                if seq_group_meta.state.num_steps == 1:
-                    assert is_first_step_output is True
-                    seq_group.update_num_computed_tokens(token_chunk_size)
-                    return
-
-                # multi-step prefill is only supported when multi-step is
-                # enabled with chunked prefill
-                assert self.scheduler_config.is_multi_step and \
-                    self.scheduler_config.chunked_prefill_enabled
-                if is_first_step_output is True:
-                    # This sequence is a prompt during the first step only.
-                    seq_group.update_num_computed_tokens(token_chunk_size)
-                return
-
-            assert is_first_step_output is None
-
-            # multi-step prefill is only supported when multi-step is
-            # enabled with chunked prefill. Outputs from all the steps are
-            # submitted in a single burst.
-            assert self.scheduler_config.is_multi_step and \
-                self.scheduler_config.chunked_prefill_enabled
-            assert num_outputs == seq_group_meta.state.num_steps, \
-                f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}" #noqa
-            # This sequence is a prompt during the first step only.
-            seq_group.update_num_computed_tokens(token_chunk_size)
-
         now = time.time()

         if len(ctx.output_queue) == 0:
@@ -1090,7 +1071,7 @@ class LLMEngine:
             seq_group_meta = seq_group_metadata_list[i]
             scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

-            seq_group = scheduled_seq_group.seq_group
+            seq_group: SequenceGroup = scheduled_seq_group.seq_group

             if seq_group.is_finished():
                 finished_before.append(i)
@@ -1101,14 +1082,14 @@ class LLMEngine:
             else:
                 output = [outputs_by_sequence_group[0][i]]

-            if not is_async and seq_group_meta.is_prompt:
-                # Updates for all decodes happen when we actually append the
-                # token ids to the seq in process_outputs.
-                update_prefill_num_computed_tokens(seq_group, seq_group_meta,
-                                                   len(output),
-                                                   is_first_step_output)
-            elif not is_async:
-                seq_group.update_num_computed_tokens(1)
+            if not is_async:
+                if self.scheduler_config.is_multi_step:
+                    # Updates happen only if the sequence is prefill
+                    self._update_num_computed_tokens_for_multi_step_prefill(
+                        seq_group, seq_group_meta, is_first_step_output)
+                else:
+                    seq_group.update_num_computed_tokens(
+                        seq_group_meta.token_chunk_size)

             if outputs:
                 for o in outputs:
@@ -1132,16 +1113,8 @@ class LLMEngine:
             else:
                 self.output_processor.process_prompt_logprob(seq_group, output)
                 if seq_group_meta.do_sample:
-                    output_token_num = self.output_processor.process_outputs(
+                    self.output_processor.process_outputs(
                         seq_group, output, is_async)
-                    if self.speculative_config:
-                        # We -1 here because we always
-                        # (w/o speculative decoding) add the number of
-                        # computed tokens by one in the decoding phase.
-                        # Therefore, we remove that one token that
-                        # is already added.
-                        seq_group.update_num_computed_tokens(output_token_num -
-                                                             1)

             if seq_group.is_finished():
                 finished_now.append(i)
@@ -1250,20 +1223,15 @@ class LLMEngine:
             if seq_group.is_finished():
                 continue

-            if seq_group_metadata.is_prompt:
-                if self.scheduler_config.is_multi_step and \
-                        self.scheduler_config.chunked_prefill_enabled:
-                    # Prompts are scheduled in multi-step only when
-                    # chunking is enabled. These prompts turn into
-                    # decodes after the very first step. Therefore,
-                    # we skip the update to the num_computed_tokens
-                    # here.
-                    seq_group.update_num_computed_tokens(1)
-                else:
-                    seq_group.update_num_computed_tokens(
-                        seq_group_metadata.token_chunk_size)
-            else:
-                seq_group.update_num_computed_tokens(1)
+            if self.scheduler_config.is_multi_step:
+                # Updates happen only if the sequence is prefill
+                self._update_num_computed_tokens_for_multi_step_prefill(
+                    seq_group, seq_group_metadata,
+                    seq_group.state.num_steps == 1)
+            else:
+                seq_group.update_num_computed_tokens(
+                    seq_group_metadata.token_chunk_size)
+
             if seq_group_metadata.do_sample:
                 assert len(sequence_group_outputs.samples) == 1, (
                     "Async output processor expects a single sample"
@@ -1273,6 +1241,14 @@ class LLMEngine:

                 assert len(seq_group.seqs) == 1
                 seq = seq_group.seqs[0]
-                seq.append_token_id(sample.output_token, sample.logprobs)
+
+                if self.scheduler_config.is_multi_step:
+                    is_prefill_append = seq.data.get_num_uncomputed_tokens(
+                    ) == 0
+                    seq.append_token_id(sample.output_token, sample.logprobs)
+                    if not is_prefill_append:
+                        seq_group.update_num_computed_tokens(1)
+                else:
+                    seq.append_token_id(sample.output_token, sample.logprobs)

     def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
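Note: the async-append change above preserves the same invariant the test checks: the token sampled by the prefill step must not bump num_computed_tokens, because the prefill update already counted the whole prompt chunk. A minimal sketch with a stand-in class (TinySeq is illustrative, not vLLM's Sequence):

# Minimal sketch of the append-time bookkeeping above: the counter is bumped
# only for decode appends.


class TinySeq:

    def __init__(self, prompt_len: int):
        self.tokens = list(range(prompt_len))  # dummy prompt token ids
        self.num_computed_tokens = 0

    def get_num_uncomputed_tokens(self) -> int:
        return len(self.tokens) - self.num_computed_tokens


seq = TinySeq(prompt_len=10)
seq.num_computed_tokens = 10              # prefill counted the whole prompt

for token in [101, 102, 103]:             # sampled token ids, one per step
    is_prefill_append = seq.get_num_uncomputed_tokens() == 0
    seq.tokens.append(token)              # append_token_id(...)
    if not is_prefill_append:
        seq.num_computed_tokens += 1      # update_num_computed_tokens(1)

# num_computed_tokens always trails the sequence length by one.
assert seq.num_computed_tokens == len(seq.tokens) - 1 == 12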
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, List, Optional
+from typing import Callable, List

 from vllm.config import SchedulerConfig
 from vllm.core.scheduler import Scheduler
@@ -58,14 +58,10 @@ class SequenceGroupOutputProcessor(ABC):
     @abstractmethod
     def process_outputs(self, sequence_group: SequenceGroup,
                         outputs: List[SequenceGroupOutput],
-                        is_async: bool) -> Optional[int]:
+                        is_async: bool) -> None:
         """Process new token ids for the sequence group. Handles logic such as
         detokenization, stop checking, and freeing/forking sequences in the
         scheduler.
-
-        Return the number of new tokens generated in the sequence group.
-        The returned value is optional because it is only used for
-        speculative decoding mqa scorer.
         """
         pass

@@ -1,5 +1,5 @@
 import functools
-from typing import Callable, List, Optional
+from typing import Callable, List

 from vllm.core.scheduler import Scheduler
 from vllm.engine.output_processor.interfaces import (
@@ -69,7 +69,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
     def process_outputs(self,
                         sequence_group: SequenceGroup,
                         outputs: List[SequenceGroupOutput],
-                        is_async: bool = False) -> Optional[int]:
+                        is_async: bool = False) -> None:
         """Append new tokens in the outputs to sequences in the sequence group.

         This only supports sequence groups of size 1. It supports greater than
@@ -84,10 +84,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                 tokens from the previous step. If this is true, then
                 no tokens need to be appended since it is already done
                 externally (before the next schedule() call)
-
-        Returns:
-            The number of tokens appended to the sequence. This is optional
-            because only speculative decode uses this return value.
         """
         # Sequences can be in RUNNING or FINISHED_ABORTED state
         # once scheduled, as a sequence is moved to FINSIHED_ABORTED
@@ -110,7 +106,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
             # was already appended, so we only need to do the rest of the
             # postprocessor: Detokenization + stopping logic
             self._process_decode_and_stop(seq, sequence_group.sampling_params)
-            return None
         else:
             # Standard multi-step case

|
|||||||
]
|
]
|
||||||
assert valid_samples
|
assert valid_samples
|
||||||
|
|
||||||
return self._process_seq_outputs(seq, valid_samples,
|
self._process_seq_outputs(seq, valid_samples,
|
||||||
sequence_group.sampling_params)
|
sequence_group.sampling_params)
|
||||||
|
|
||||||
def _process_decode_and_stop(self, seq: Sequence,
|
def _process_decode_and_stop(self, seq: Sequence,
|
||||||
@@ -145,7 +140,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):

     def _process_seq_outputs(self, seq: Sequence,
                              valid_samples: List[SequenceOutput],
-                             sampling_params: SamplingParams) -> int:
+                             sampling_params: SamplingParams) -> None:
         output_token_ids = [sample.output_token for sample in valid_samples]
         output_logprobs = [sample.logprobs for sample in valid_samples]

@@ -168,6 +163,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                 output_token_ids = output_token_ids[:i + 1]
                 break

+        is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
         # Incrementally append tokens to the sequence, as if we had only one new
         # token.
         for output_token_id, output_logprob in zip(output_token_ids,
@@ -177,8 +173,14 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                 logprobs=output_logprob,
             )

+            if is_prefill_sampled_token:
+                is_prefill_sampled_token = False
+            else:
+                # Update num_computed_tokens iff the sampled token is not from
+                # a prefill step.
+                seq.data.update_num_computed_tokens(1)
+
             self._process_decode_and_stop(seq, sampling_params)

             if seq.is_finished():
                 break
-        return len(output_token_ids)
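Note: the same skip-the-prefill-token rule applies here to a whole multi-step burst of sampled tokens. A standalone sketch of the loop's bookkeeping follows (illustrative names; the real code operates on Sequence and SequenceData).

# Standalone sketch of the burst-append bookkeeping above: the first sampled
# token of a prefill burst is skipped, every later one bumps the counter.
from typing import List, Tuple


def apply_burst(num_computed_tokens: int, seq_len: int,
                sampled_tokens: List[int],
                is_prefill_burst: bool) -> Tuple[int, int]:
    skip_first = is_prefill_burst
    for _token in sampled_tokens:
        seq_len += 1                   # append_token_id(...)
        if skip_first:
            skip_first = False         # the prefill step already counted it
        else:
            num_computed_tokens += 1   # update_num_computed_tokens(1)
    return num_computed_tokens, seq_len


# Prompt of 10 tokens, fully counted by the prefill update, followed by an
# 8-token multi-step burst whose first token came from the prefill step.
computed, length = apply_burst(10, 10, list(range(8)), is_prefill_burst=True)
assert (computed, length) == (17, 18)  # still trails the length by one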