# SPDX-License-Identifier: Apache-2.0

import math
import time
from typing import Dict, List, Optional

import pytest

from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
                                   NUM_SAMPLE_LOGPROBS_UNDER_TEST,
                                   STOP_STRINGS,
                                   DummyOutputProcessorTestVectors,
                                   MockEngineCore)
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.metrics.stats import IterationStats
|
def _ref_convert_id_to_token(
    tokenizer: AnyTokenizer,
    token_id: int,
) -> str:
    """Reference implementation of single-token logprob detokenization.

    Args:
      tokenizer: tokenizer used by the model under test
      token_id: token id to convert

    Returns:
      String representation of the input token id; empty string if the
      tokenizer yields a falsy result (e.g. ``None``) for this id.
    """
    decoded = tokenizer.convert_ids_to_tokens(token_id)
    return decoded if decoded else ""


@pytest.mark.parametrize(
    "request_output_kind",
    [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
def test_incremental_detokenization(request_output_kind: RequestOutputKind,
                                    dummy_test_vectors):
    """Incremental detokenization produces the reference text and tokens.

    Feeds mock EngineCore outputs through the OutputProcessor for both
    DELTA and FINAL_ONLY output kinds, accumulates the per-step text and
    token ids per request, and compares the accumulated values against
    the reference strings/tokens from the dummy test vectors.
    """
    output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
                                       log_stats=False)
    engine_core = MockEngineCore(
        tokens_list=dummy_test_vectors.generation_tokens)

    # Make N requests.
    requests = [
        EngineCoreRequest(request_id=f"request-{idx}",
                          prompt=prompt,
                          prompt_token_ids=prompt_tokens,
                          arrival_time=0,
                          mm_inputs=None,
                          mm_hashes=None,
                          mm_placeholders=None,
                          eos_token_id=None,
                          lora_request=None,
                          sampling_params=SamplingParams(
                              skip_special_tokens=False,
                              spaces_between_special_tokens=False,
                              output_kind=request_output_kind,
                              stop=[],
                              include_stop_str_in_output=False,
                          )) for idx, (prompt, prompt_tokens) in enumerate(
                              zip(dummy_test_vectors.prompt_strings,
                                  dummy_test_vectors.prompt_tokens))
    ]

    # Add requests to the detokenizer.
    for request in requests:
        output_processor.add_request(request)

    # request id -> accumulated generated text / token ids.
    gen_strings: Dict[str, str] = {}
    gen_tokens: Dict[str, List[int]] = {}
    while True:
        # Mock output from the EngineCore.
        outputs = engine_core.get_outputs()
        if len(outputs) == 0:
            break

        # Step the Detokenizer.
        processed_outputs = output_processor.process_outputs(outputs)
        request_outputs = processed_outputs.request_outputs
        requests_to_abort = processed_outputs.reqs_to_abort
        # No stop strings configured, so nothing should be aborted.
        assert len(requests_to_abort) == 0

        # Update tracking.
        for request_output in request_outputs:
            request_id = request_output.request_id
            new_text = request_output.outputs[0].text
            new_tokens = request_output.outputs[0].token_ids
            if request_id not in gen_strings:
                gen_strings[request_id] = new_text
                gen_tokens[request_id] = new_tokens
            else:
                gen_strings[request_id] += new_text
                gen_tokens[request_id].extend(new_tokens)

    # Confirmed tracked values matches what we expected.
    for idx, (ref_gen_str, ref_gen_toks) in enumerate(
            zip(dummy_test_vectors.generation_strings,
                dummy_test_vectors.generation_tokens)):
        gen_str = gen_strings[f"request-{idx}"]
        gen_toks = gen_tokens[f"request-{idx}"]

        assert gen_str == ref_gen_str, f"{gen_str=}, {ref_gen_str=}"
        assert gen_toks == ref_gen_toks, f"{gen_toks=}, {ref_gen_toks=}"

    # All requests must have been drained to completion.
    assert output_processor.get_num_unfinished_requests() == 0
    assert not output_processor.has_unfinished_requests()


def _validate_logprobs(
    gen_tokens: Dict[str, List[int]],
    gen_logprobs: Dict[str, Optional[SampleLogprobs]],
    gen_prompt_logprobs: Dict[str, Optional[PromptLogprobs]],
    gen_cumulative_logprob: Dict[str, float],
    dtv: DummyOutputProcessorTestVectors,
    request_id_list: List[str],
    num_sample_logprobs: Optional[int],
    num_prompt_logprobs: Optional[int],
) -> None:
    """Validate accumulated sample/prompt logprobs against reference values.

    For each request: checks logprob counts, ranks, values (via
    ``math.isclose``), detokenized logprob tokens, and cumulative logprob.
    When sample/prompt logprobs are disabled (``None``), asserts the
    corresponding fields are ``None``.

    Args:
      gen_tokens: request ID -> accumulated completion token ids
      gen_logprobs: request ID -> accumulated sample logprobs, or None
      gen_prompt_logprobs: request ID -> accumulated prompt logprobs, or None
      gen_cumulative_logprob: request ID -> cumulative completion logprob
      dtv: dummy test vectors holding reference prompts/completions/logprobs
      request_id_list: request IDs in the same order as the test vectors
      num_sample_logprobs: requested number of sample logprobs (None=disabled)
      num_prompt_logprobs: requested number of prompt logprobs (None=disabled)
    """
    for req_idx, req_id in enumerate(request_id_list):
        new_tokens = gen_tokens[req_id]
        logprobs = gen_logprobs[req_id]
        prompt_logprobs = gen_prompt_logprobs[req_id]
        cumulative_logprob = gen_cumulative_logprob[req_id]
        prompt_token_ids = dtv.prompt_tokens[req_idx]
        ref_logprobs = dtv.generation_logprobs[req_idx]
        ref_prompt_logprobs = dtv.prompt_logprobs[req_idx]
        if num_sample_logprobs is not None:
            # Validate sample logprobs
            assert logprobs is not None, (f"Request {req_id} requires sample"
                                          " logprobs but sample logprobs are"
                                          " None.")
            # Require num sampled tokens to match num
            # sampled logprobs - especially important
            # to check since the detokenizer can cause
            # a request to finish early due to a stop
            # string being hit
            num_new_tokens = len(new_tokens)
            len_sample_logprobs = len(logprobs)
            assert num_new_tokens == len_sample_logprobs, (
                f"Request {req_id} has {num_new_tokens}"
                " completion tokens but has"
                f" {len_sample_logprobs} sample logprobs.")
            ref_cumulative_logprob = 0.0
            for idx, (sampled_token,
                      pos_logprob_dict) in enumerate(zip(new_tokens,
                                                         logprobs)):
                # Break out the reference log probability value &
                # logprob token id tensors associated with this
                # position in the completion. Also break out the
                # sampled token ranks
                (ref_pos_logprob_toks, ref_pos_logprob_vals,
                 ref_sampled_token_rank) = ref_logprobs[idx]
                # For each position in the completion sequence,
                # ensure the actual sampled token is among the
                # logprobs
                assert sampled_token in pos_logprob_dict, (
                    f"Sampled token {sampled_token} not"
                    f" present in logprob at index {idx}")

                # Validate number of sample logprobs
                # (num_sample_logprobs + 1 is allowed: the sampled token
                # itself may fall outside the top-k and be appended).
                num_lp_toks = len(pos_logprob_dict)
                assert (num_lp_toks == num_sample_logprobs
                        or num_lp_toks == num_sample_logprobs +
                        1), ("Valid numbers of sample logprobs are"
                             f" {num_sample_logprobs} or"
                             f" {num_sample_logprobs+1} but"
                             f" {num_lp_toks} logprobs found at"
                             f" position {idx}. Logprobs dict:"
                             f" {pos_logprob_dict}")

                # Validate sampled token logprob rank
                smp_lp = pos_logprob_dict[sampled_token]
                smp_lp_rank = smp_lp.rank
                assert (ref_sampled_token_rank == smp_lp_rank), (
                    "Sampled token logprob rank"
                    f" {smp_lp_rank} does not match"
                    " correct value"
                    f" {ref_sampled_token_rank}"
                    f" in Logprob {smp_lp}")

                # Validate that the logprob processor yields
                # the correct log probabilities and valid
                # rankings
                rank_one_appears = False
                # NOTE(review): loop starts at 1 — index 0 presumably holds
                # the sampled token's own entry; confirm against the test
                # fixture layout in tests.v1.engine.utils.
                for jdx in range(1, len(ref_pos_logprob_toks)):
                    # Iterate over the (logprob val,logprob tok id)
                    # pairs expected by the test fixture at this
                    # position in the completion.
                    ref_lp_val = ref_pos_logprob_vals[jdx]
                    ref_tok_id = ref_pos_logprob_toks[jdx]
                    assert ref_tok_id in pos_logprob_dict, (
                        f"Expected token {ref_tok_id} to be"
                        f" in logprob dict but it is not.")

                    # Extract actually-generated logprob
                    # info
                    lp = pos_logprob_dict[ref_tok_id]
                    lp_val = lp.logprob
                    lp_rank = lp.rank

                    # A "top" (rank 1) logprob must be
                    # present
                    rank_one_appears = (True
                                        if lp_rank == 1 else rank_one_appears)

                    # Rank must be >= 1
                    assert lp_rank >= 1, (f"Logprob {lp} has invalid"
                                          f" rank {lp_rank} < 1."
                                          f" Logprob dict: {pos_logprob_dict}")

                    # Validate log probability
                    assert math.isclose(lp_val, ref_lp_val), (
                        f"Token id {ref_tok_id} appears in logprobs dict"
                        f" at position {idx} in completion with log"
                        f" probability {lp_val} but {ref_lp_val} was"
                        f" expected. Logprob: {lp}")

                assert rank_one_appears, (f"No Logprob has rank 1"
                                          " in the following Logprob"
                                          f" dict: {pos_logprob_dict}")

                # Validate logprobs detokenization
                for lp_tok in pos_logprob_dict:
                    # Confirm that sample logprob decoded token matches
                    # the logprob token id at this sequence position
                    decoded_token = pos_logprob_dict[lp_tok].decoded_token
                    ref_decoded_token = _ref_convert_id_to_token(
                        dtv.tokenizer, lp_tok)
                    assert decoded_token == ref_decoded_token, (
                        f"Sampled logprob token id {lp_tok} decodes to"
                        f" {ref_decoded_token} but Logprob decoded"
                        f" token is {decoded_token} instead"
                        f" (at position {idx})")

                ref_cumulative_logprob += pos_logprob_dict[
                    sampled_token].logprob
            # Assert that cumulative logprobs are correct
            assert math.isclose(cumulative_logprob, ref_cumulative_logprob)
        else:
            # Sample logprobs disabled for this request
            assert logprobs is None
            assert cumulative_logprob is None

        if num_prompt_logprobs is not None:
            # Validate prompt logprobs
            assert prompt_logprobs is not None, (
                f"Request {req_id} requires prompt"
                " logprobs but prompt logprobs are"
                " None.")
            # Require num prompt tokens to match num
            # prompt logprobs
            num_prompt_tokens = len(prompt_token_ids)
            len_prompt_logprobs = len(prompt_logprobs)
            assert num_prompt_tokens == len_prompt_logprobs, (
                f"Request {req_id} has {num_prompt_tokens}"
                " prompt tokens but has"
                f" {len_prompt_logprobs} prompt logprobs.")
            # First prompt logprob is None
            first_plp_dict = prompt_logprobs[0]
            assert first_plp_dict is None, (
                f"Request {req_id} first prompt logprob"
                f" should be None but has following value"
                f" instead: {first_plp_dict}")
            # Break out the reference prompt log prob value &
            # logprob token id matrices for the whole prompt.
            # Also break out the prompt token rank vector
            (ref_prompt_logprob_toks, ref_prompt_logprob_vals,
             ref_prompt_token_ranks) = ref_prompt_logprobs
            for idx, (prompt_token, pos_logprob_dict) in enumerate(
                    zip(prompt_token_ids[1:], prompt_logprobs[1:])):

                # Break out the reference prompt log prob value
                # vector, prompt logprob token id vector, and
                # prompt token rank at the current position.
                (ref_pos_prompt_logprob_toks, ref_pos_prompt_logprob_vals,
                 ref_pos_prompt_token_rank) = (ref_prompt_logprob_toks[idx, :],
                                               ref_prompt_logprob_vals[idx, :],
                                               ref_prompt_token_ranks[idx])

                # For each position in the prompt sequence,
                # ensure the actual prompt token is among the
                # logprobs
                assert prompt_token in pos_logprob_dict, (
                    f"Prompt token {prompt_token} not"
                    f" present in logprob at index {idx}")
                # Validate number of prompt logprobs
                num_plp_toks = len(pos_logprob_dict)
                assert (num_plp_toks == num_prompt_logprobs
                        or num_plp_toks == num_prompt_logprobs +
                        1), ("Valid numbers of prompt logprobs are"
                             f" {num_prompt_logprobs} or"
                             f" {num_prompt_logprobs+1} but"
                             f" {num_plp_toks} logprobs found at"
                             f" position {idx}. Logprobs dict:"
                             f" {pos_logprob_dict}")

                # Validate prompt token logprob rank
                prmpt_tok_lp = pos_logprob_dict[prompt_token]
                prmpt_tok_lp_rank = prmpt_tok_lp.rank
                ref_prmpt_tok_lp_rank = ref_pos_prompt_token_rank
                assert (ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank), (
                    "Prompt token logprob rank"
                    f" {prmpt_tok_lp_rank} does not match"
                    " correct value"
                    f" {ref_prmpt_tok_lp_rank}"
                    f" in Logprob {prmpt_tok_lp}")

                # Validate that the logprob processor yields
                # the correct prompt log probs and valid
                # rankings
                rank_one_appears = False
                for jdx in range(1, len(ref_pos_prompt_logprob_toks)):
                    # Iterate over the (logprob val,logprob tok id)
                    # pairs expected by the test fixture at this
                    # position in the completion.
                    ref_plp_val = float(ref_pos_prompt_logprob_vals[jdx])
                    ref_tok_id = int(ref_pos_prompt_logprob_toks[jdx])
                    assert ref_tok_id in pos_logprob_dict, (
                        f"Expected token {ref_tok_id} to be"
                        f" in logprob dict but it is not.")

                    # Extract actually-generated logprob
                    # info
                    plp = pos_logprob_dict[ref_tok_id]
                    plp_val = plp.logprob
                    plp_rank = plp.rank

                    # A "top" (rank 1) logprob must be
                    # present
                    rank_one_appears = (True
                                        if plp_rank == 1 else rank_one_appears)

                    # Rank must be >= 1
                    assert plp_rank >= 1, (
                        f"Logprob {plp} has invalid"
                        f" rank {plp_rank} < 1."
                        f" Logprob dict: {pos_logprob_dict}")

                    # Validate log probability
                    assert math.isclose(plp_val, ref_plp_val), (
                        f"Token id {ref_tok_id} appears in logprobs dict"
                        f" at position {idx} in completion with log"
                        f" probability {plp_val} but {ref_plp_val} was"
                        f" expected. Logprob: {plp}")

                assert rank_one_appears, (f"No Logprob has rank 1"
                                          " in the following Logprob"
                                          f" dict: {pos_logprob_dict}")

                # Validate prompt logprob detokenization
                for plp_tok in pos_logprob_dict:
                    # Confirm that prompt logprob decoded token matches
                    # the logprob token id at this sequence position
                    decoded_token = pos_logprob_dict[plp_tok].decoded_token
                    ref_decoded_token = _ref_convert_id_to_token(
                        dtv.tokenizer, plp_tok)
                    assert decoded_token == ref_decoded_token, (
                        f"Prompt logprob token id {plp_tok} decodes to"
                        f" {ref_decoded_token} but Logprob decoded"
                        f" token is {decoded_token} instead"
                        f" (at position {idx})")
        else:
            # Prompt logprobs disabled for this request
            assert prompt_logprobs is None


@pytest.mark.parametrize(
    "request_output_kind",
    [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize("num_sample_logprobs",
                         [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
@pytest.mark.parametrize("num_prompt_logprobs",
                         [None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
def test_logprobs_processor(request_output_kind: RequestOutputKind,
                            num_sample_logprobs: Optional[int],
                            num_prompt_logprobs: Optional[int],
                            dummy_test_vectors):
    """Sample/prompt logprobs flow correctly through the OutputProcessor.

    Parametrized over output kind and over sample/prompt logprobs being
    enabled or disabled. Accumulates tokens, logprobs, prompt logprobs and
    cumulative logprobs per request across processor steps, then validates
    everything against the reference vectors via ``_validate_logprobs``.
    """
    output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
                                       log_stats=False)
    engine_core = MockEngineCore(
        tokens_list=dummy_test_vectors.generation_tokens,
        generated_logprobs_raw=None if num_sample_logprobs is None else
        dummy_test_vectors.generation_logprobs,
        prompt_logprobs_raw=None
        if num_prompt_logprobs is None else dummy_test_vectors.prompt_logprobs)

    # Make N requests.
    request_id_list = [
        f"request-{idx}"
        for idx in range(len(dummy_test_vectors.prompt_strings))
    ]
    requests = [
        EngineCoreRequest(request_id=request_id_list[idx],
                          prompt=prompt,
                          prompt_token_ids=prompt_tokens,
                          arrival_time=0,
                          mm_inputs=None,
                          mm_hashes=None,
                          mm_placeholders=None,
                          eos_token_id=None,
                          lora_request=None,
                          sampling_params=SamplingParams(
                              skip_special_tokens=False,
                              spaces_between_special_tokens=False,
                              output_kind=request_output_kind,
                              stop=[],
                              include_stop_str_in_output=False,
                              logprobs=num_sample_logprobs,
                              prompt_logprobs=num_prompt_logprobs,
                          )) for idx, (prompt, prompt_tokens) in enumerate(
                              zip(dummy_test_vectors.prompt_strings,
                                  dummy_test_vectors.prompt_tokens))
    ]

    # Add requests to the detokenizer.
    for request in requests:
        output_processor.add_request(request)

    # Per-request accumulators, keyed by request id.
    gen_tokens = {}
    gen_logprobs = {}
    gen_prompt_logprobs = {}
    gen_cumulative_logprobs = {}
    while True:
        # Mock output from the EngineCore.
        outputs = engine_core.get_outputs()
        if len(outputs) == 0:
            break

        # Step the logprobs processor.
        processed_outputs = output_processor.process_outputs(outputs)
        request_outputs = processed_outputs.request_outputs
        requests_to_abort = processed_outputs.reqs_to_abort
        # No stop strings configured, so nothing should be aborted.
        assert len(requests_to_abort) == 0

        # Update tracking.
        for request_output in request_outputs:
            request_id = request_output.request_id
            new_tokens = request_output.outputs[0].token_ids
            prompt_logprobs = request_output.prompt_logprobs
            logprobs = request_output.outputs[0].logprobs
            gen_cumulative_logprobs[request_id] = request_output.outputs[
                0].cumulative_logprob
            if request_id not in gen_logprobs:
                # Start tracking sample and prompt logprobs for this request
                gen_tokens[request_id] = new_tokens
                gen_logprobs[request_id] = logprobs
                gen_prompt_logprobs[request_id] = prompt_logprobs
            else:
                # Extend logprobs tracker
                gen_tokens[request_id].extend(new_tokens)
                lp = gen_logprobs[request_id]
                plp = gen_prompt_logprobs[request_id]
                if lp:
                    lp.extend(logprobs)
                if plp:
                    plp.extend(prompt_logprobs)

    # Confirmed tracked logprobs match what we expect
    _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
                       gen_cumulative_logprobs, dummy_test_vectors,
                       request_id_list, num_sample_logprobs,
                       num_prompt_logprobs)

    # All requests must have been drained to completion.
    assert output_processor.get_num_unfinished_requests() == 0
    assert not output_processor.has_unfinished_requests()


@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
@pytest.mark.parametrize("num_sample_logprobs",
                         [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
@pytest.mark.parametrize("num_prompt_logprobs",
                         [None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
def test_stop_string(include_stop_str_in_output: bool,
                     num_sample_logprobs: Optional[int],
                     num_prompt_logprobs: Optional[int], dummy_test_vectors):
    """Stop strings abort requests and truncate output at the stop string.

    Every request uses DELTA output kind with STOP_STRINGS configured.
    Verifies that each request is reported in reqs_to_abort, that no
    further outputs arrive after the abort, that the accumulated text
    matches the reference truncated at (or including) the stop string,
    and that accumulated logprobs remain consistent.
    """
    output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
                                       log_stats=False)
    engine_core = MockEngineCore(
        tokens_list=dummy_test_vectors.generation_tokens,
        generated_logprobs_raw=dummy_test_vectors.generation_logprobs
        if num_sample_logprobs else None,
        prompt_logprobs_raw=dummy_test_vectors.prompt_logprobs
        if num_prompt_logprobs else None)

    # Make N requests.
    request_id_list = [
        f"request-{idx}"
        for idx in range(len(dummy_test_vectors.prompt_strings))
    ]
    requests = [
        EngineCoreRequest(
            request_id=request_id_list[idx],
            prompt=prompt,
            prompt_token_ids=prompt_tokens,
            arrival_time=0,
            mm_inputs=None,
            mm_hashes=None,
            mm_placeholders=None,
            eos_token_id=None,
            lora_request=None,
            sampling_params=SamplingParams(
                skip_special_tokens=False,
                spaces_between_special_tokens=False,
                output_kind=RequestOutputKind.DELTA,
                stop=STOP_STRINGS,
                include_stop_str_in_output=include_stop_str_in_output,
                logprobs=num_sample_logprobs,
                prompt_logprobs=num_prompt_logprobs,
            )) for idx, (prompt, prompt_tokens) in enumerate(
                zip(dummy_test_vectors.prompt_strings,
                    dummy_test_vectors.prompt_tokens))
    ]

    # Add requests to the detokenizer.
    for request in requests:
        output_processor.add_request(request)

    # Per-request accumulators, keyed by request id.
    gen_strings = {}
    gen_tokens = {}
    gen_logprobs = {}
    gen_prompt_logprobs = {}
    gen_cumulative_logprobs = {}
    aborted = []
    while True:
        # Mock output from the EngineCore.
        outputs = engine_core.get_outputs()
        if len(outputs) == 0:
            break

        # Step the Detokenizer.
        processed_outputs = output_processor.process_outputs(outputs)
        request_outputs = processed_outputs.request_outputs
        requests_to_abort = processed_outputs.reqs_to_abort
        for request_output in request_outputs:
            # If aborted, we should not get a request output.
            assert request_output.request_id not in aborted
        aborted.extend(requests_to_abort)

        # Update tracking.
        for request_output in request_outputs:
            if request_output.finished:
                # Stop-string termination must be reported as "stop".
                assert request_output.outputs[0].finish_reason == "stop"

            request_id = request_output.request_id
            new_text = request_output.outputs[0].text
            new_tokens = request_output.outputs[0].token_ids
            prompt_logprobs = request_output.prompt_logprobs
            logprobs = request_output.outputs[0].logprobs
            gen_cumulative_logprobs[request_id] = request_output.outputs[
                0].cumulative_logprob
            if request_id not in gen_strings:
                gen_strings[request_id] = new_text
                gen_tokens[request_id] = new_tokens
                gen_logprobs[request_id] = logprobs
                gen_prompt_logprobs[request_id] = prompt_logprobs
            else:
                gen_strings[request_id] += new_text
                gen_tokens[request_id].extend(new_tokens)
                lp = gen_logprobs[request_id]
                plp = gen_prompt_logprobs[request_id]
                if lp:
                    lp.extend(logprobs)
                if plp:
                    plp.extend(prompt_logprobs)

    # Confirmed tracked values matches what we expected.
    for idx, (ref_gen_str, stop_str) in enumerate(
            zip(dummy_test_vectors.generation_strings, STOP_STRINGS)):

        # Request should be aborted.
        request_id = f"request-{idx}"
        assert request_id in aborted

        # Collected values that were generated.
        gen_str = gen_strings[request_id]

        # Construct reference strings.
        stop_str_idx = ref_gen_str.find(stop_str)
        ref_str_exc_stop = ref_gen_str[:stop_str_idx]
        ref_str_inc_stop = ref_gen_str[:stop_str_idx] + stop_str

        if include_stop_str_in_output:
            assert gen_str == ref_str_inc_stop, (
                f"{gen_str=}, {ref_str_inc_stop=}")
        else:
            assert gen_str == ref_str_exc_stop, (
                f"{gen_str=}, {ref_str_exc_stop=}")

    # Confirmed tracked logprobs match what we expect
    _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
                       gen_cumulative_logprobs, dummy_test_vectors,
                       request_id_list, num_sample_logprobs,
                       num_prompt_logprobs)

    # All requests must have been drained to completion.
    assert output_processor.get_num_unfinished_requests() == 0
    assert not output_processor.has_unfinished_requests()


def test_iteration_stats(dummy_test_vectors):
    """Verify prompt/generation token accounting in IterationStats.

    Runs the processor with log_stats enabled: a prefill step for the
    initially-added requests, a pure decode step, a mixed step after a
    late-added request, and a final decode step, asserting token counts
    after each.
    """
    output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
                                       log_stats=True)
    engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
    engine_core_timestamp = time.monotonic()

    # Build one request per dummy prompt.
    requests = []
    for idx, (prompt, prompt_tokens) in enumerate(
            zip(dummy_test_vectors.prompt_strings,
                dummy_test_vectors.prompt_tokens)):
        requests.append(
            EngineCoreRequest(
                request_id=f"request-{idx}",
                prompt=prompt,
                prompt_token_ids=prompt_tokens,
                arrival_time=0,
                mm_inputs=None,
                mm_hashes=None,
                mm_placeholders=None,
                eos_token_id=None,
                lora_request=None,
                sampling_params=SamplingParams(),
            ))

    # Add all requests except one to the OutputProcessor.
    num_active = len(dummy_test_vectors.generation_tokens) - 1
    for request in requests[:num_active]:
        output_processor.add_request(request)
    inactive_request = requests[num_active]

    def step_and_get_stats(active_count):
        # Run one engine-core step through the processor, returning the
        # IterationStats collected for that step.
        stats = IterationStats()
        step_outputs = engine_core.get_outputs()[:active_count]
        output_processor.process_outputs(step_outputs, engine_core_timestamp,
                                         stats)
        return stats

    # First iteration has 2 prefills.
    iteration_stats = step_and_get_stats(num_active)
    total_prompt_tokens = sum(
        len(prompt_tokens)
        for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active])

    assert iteration_stats.num_prompt_tokens == total_prompt_tokens
    assert iteration_stats.num_generation_tokens == num_active

    # Just decodes in this step.
    iteration_stats = step_and_get_stats(num_active)

    assert iteration_stats.num_prompt_tokens == 0
    assert iteration_stats.num_generation_tokens == num_active

    # Add a new request - prefill and 2 decodes in this step.
    output_processor.add_request(inactive_request)
    num_active += 1
    iteration_stats = step_and_get_stats(num_active)
    total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])

    assert iteration_stats.num_prompt_tokens == total_prompt_tokens
    assert iteration_stats.num_generation_tokens == num_active

    # Just decodes in this step.
    iteration_stats = step_and_get_stats(num_active)

    assert iteration_stats.num_prompt_tokens == 0
    assert iteration_stats.num_generation_tokens == num_active
|