[V1] Detokenizer: Respect Stop Tokens + not include_stop_str_in_output (#14624)
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
parent 8a4a2efc6f
commit 02fcaa3d0a
@@ -470,22 +470,184 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
     assert not output_processor.has_unfinished_requests()
 
 
+@pytest.mark.parametrize(
+    "include_stop_str_in_output,stop_token_type,ignore_eos,num_sample_logprobs",
+    [(False, "stop_token_ids", False, None),
+     (True, "stop_token_ids", False, None),
+     (False, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
+     (True, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
+     (False, "eos_token_id", False, None), (True, "eos_token_id", False, None),
+     (False, "eos_token_id", True, None)])
+def test_stop_token(include_stop_str_in_output: bool,
+                    num_sample_logprobs: Optional[int], stop_token_type: str,
+                    ignore_eos: bool, dummy_test_vectors):
+    """Test output processor EOS/stop token handling.
+
+    Send a mock engine core request to the mock engine core and pass core
+    outputs to the output processor. Validate output processor tokens, text
+    and (if enabled) sample logprobs. Batch size one.
+
+    The test emulates a scenario where a model outputs text tokens followed
+    by two identical control tokens:
+    <token><token>...<token><control><control>
+
+    If EOS is under test, the control tokens are EOS; otherwise, they are
+    some other token id.
+
+    Test behavior:
+
+    * If EOS is under test and `ignore_eos=True`, the detokenized string
+      should be <token><token>...<token><control><control> and the finish
+      reason should be "length" (i.e. no stop occurs.)
+
+    * Else, if `include_stop_str_in_output==True`, the detokenized string
+      should be <token><token>...<token><control> and the finish reason
+      should be "stop" (i.e. the first control token causes the stop and
+      is represented in the output text.)
+
+    * Else, the detokenized string should be <token><token>...<token> and
+      the finish reason should be "stop" (i.e. the first control token
+      causes the stop but is not represented in the output text.)
+
+    Note: some test details are tuned for meta-llama/Llama-3.2-1B; another
+    model will work only if the test is modified.
+
+    Args:
+      include_stop_str_in_output: stop token str appears in output text
+      num_sample_logprobs: number of sample logprobs (`None` for no logprobs)
+      stop_token_type: "eos_token_id" for EOS, "stop_token_ids" for stop token
+      ignore_eos: if True, EOS stops are disabled
+      dummy_test_vectors: dummy engine core outputs and other data structures
+    """
+    model_id = dummy_test_vectors.tokenizer.name_or_path
+    if model_id != 'meta-llama/Llama-3.2-1B':
+        raise AssertionError("Test requires meta-llama/Llama-3.2-1B but "
+                             f"{model_id} is in use.")
+    do_logprobs = num_sample_logprobs is not None
+    # EOS under test; if False, stop_token_ids under test
+    is_eos_test = stop_token_type == "eos_token_id"
+    # EOS under test but ignore_eos enabled
+    is_eos_ignore_test = is_eos_test and ignore_eos
+    eos_token_id = (
+        dummy_test_vectors.tokenizer.eos_token_id if is_eos_test else None
+    )  # '<|end_of_text|>'
+    stop_token_ids = [128009] if not is_eos_test else None  # '<|eot_id|>'
+
+    output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
+                                       log_stats=False)
+    # Dummy engine core outputs, with control tokens suffixed to test stops
+    suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids)
+    assert suffix_token is not None and isinstance(suffix_token[0], int)
+    generation_string = dummy_test_vectors.generation_strings[0]
+    generation_tokens = (dummy_test_vectors.generation_tokens[0] +
+                         2 * suffix_token)
+    if do_logprobs:
+        generation_logprobs = (
+            dummy_test_vectors.generation_logprobs[0] +
+            2 * [dummy_test_vectors.generation_logprobs[0][-1]])
+    prompt_string = dummy_test_vectors.prompt_strings[0]
+    prompt_tokens = dummy_test_vectors.prompt_tokens[0]
+    engine_core = MockEngineCore(
+        tokens_list=[generation_tokens],
+        generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
+        prompt_logprobs_raw=None,
+        eos_token_id=eos_token_id,
+        stop_token_ids=stop_token_ids,
+        ignore_eos=ignore_eos)
+
+    # Make request.
+    request_id = "request-0"
+    request = EngineCoreRequest(
+        request_id=request_id,
+        prompt=prompt_string,
+        prompt_token_ids=prompt_tokens,
+        arrival_time=0,
+        mm_inputs=None,
+        mm_hashes=None,
+        mm_placeholders=None,
+        eos_token_id=eos_token_id,
+        lora_request=None,
+        sampling_params=SamplingParams(
+            skip_special_tokens=False,
+            spaces_between_special_tokens=False,
+            output_kind=RequestOutputKind.DELTA,
+            stop=[],
+            stop_token_ids=stop_token_ids,
+            include_stop_str_in_output=include_stop_str_in_output,
+            logprobs=num_sample_logprobs,
+            prompt_logprobs=None,
+            ignore_eos=ignore_eos,
+        ))
+
+    # Add request to the detokenizer.
+    output_processor.add_request(request)
+
+    # Loop over engine core steps; run output processor
+    gen_string = ""
+    gen_tokens = []
+    gen_logprobs = []
+    while True:
+        # Mock output from the EngineCore.
+        outputs = engine_core.get_outputs()
+        if len(outputs) == 0:
+            break
+
+        # Step the Detokenizer.
+        processed_outputs = output_processor.process_outputs(outputs)
+        request_outputs = processed_outputs.request_outputs
+        assert len(request_outputs) == 1
+        # Stop token does not rely on abort
+        assert not processed_outputs.reqs_to_abort
+
+        # Update tracking.
+        request_output = request_outputs[0]
+        if request_output.finished:
+            finish_reason = ("length" if is_eos_ignore_test else "stop")
+            assert request_output.outputs[0].finish_reason == finish_reason
+
+        gen_string += request_output.outputs[0].text
+        gen_tokens.extend(request_output.outputs[0].token_ids)
+        if do_logprobs:
+            gen_logprobs.extend(request_output.outputs[0].logprobs)
+
+    # Validate generated text
+    control_token = '<|end_of_text|>' if is_eos_test else '<|eot_id|>'
+    if is_eos_ignore_test:
+        # Length-based stop; expect the full string
+        ref_str = generation_string + 2 * control_token
+    elif include_stop_str_in_output:
+        # Stop token triggered; included in output
+        ref_str = generation_string + control_token
+    else:
+        # Stop token triggered but not included in output
+        ref_str = generation_string
+    assert gen_string == ref_str, (f"{gen_string=}, {ref_str=}")
+
+    if do_logprobs:
+        # Validate number of sample logprobs
+        num_tokens = len(gen_tokens)
+        num_logprobs = len(gen_logprobs)
+        assert num_tokens == num_logprobs, (
+            f"Token count ({num_tokens}) != logprobs count ({num_logprobs})")
+
+    # Check requests are finished
+    assert output_processor.get_num_unfinished_requests() == 0
+    assert not output_processor.has_unfinished_requests()
+
+
 @pytest.mark.parametrize("include_stop_str_in_output", [True, False])
 @pytest.mark.parametrize("num_sample_logprobs",
                          [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
-@pytest.mark.parametrize("num_prompt_logprobs",
-                         [None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
 def test_stop_string(include_stop_str_in_output: bool,
-                     num_sample_logprobs: Optional[int],
-                     num_prompt_logprobs: Optional[int], dummy_test_vectors):
+                     num_sample_logprobs: Optional[int], dummy_test_vectors):
     output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
                                        log_stats=False)
     engine_core = MockEngineCore(
         tokens_list=dummy_test_vectors.generation_tokens,
         generated_logprobs_raw=dummy_test_vectors.generation_logprobs
         if num_sample_logprobs else None,
-        prompt_logprobs_raw=dummy_test_vectors.prompt_logprobs
-        if num_prompt_logprobs else None)
+        prompt_logprobs_raw=None)
 
     # Make N requests.
     request_id_list = [
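Not part of the diff: the new test_stop_token above encodes a three-way expectation (full text with finish reason "length", text plus one control token with reason "stop", or text alone with reason "stop"). A minimal standalone sketch of that decision table, with a hypothetical helper name, for reference:

from typing import Tuple

def expected_output(text: str, control: str, eos_ignored: bool,
                    include_stop_str_in_output: bool) -> Tuple[str, str]:
    """Mirror the three expectations documented in test_stop_token.

    `text` is the detokenized generation without control tokens and
    `control` is the control-token string, e.g. '<|eot_id|>'.
    """
    if eos_ignored:
        # EOS emitted but ignore_eos=True: no stop, the request runs to length.
        return text + 2 * control, "length"
    if include_stop_str_in_output:
        # The first control token stops the request and appears in the output.
        return text + control, "stop"
    # The first control token stops the request but is stripped from the output.
    return text, "stop"

# With include_stop_str_in_output=False the control token never reaches the text.
assert expected_output("Hello", "<|eot_id|>", False, False) == ("Hello", "stop")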
@@ -510,7 +672,7 @@ def test_stop_string(include_stop_str_in_output: bool,
                 stop=STOP_STRINGS,
                 include_stop_str_in_output=include_stop_str_in_output,
                 logprobs=num_sample_logprobs,
-                prompt_logprobs=num_prompt_logprobs,
+                prompt_logprobs=None,
             )) for idx, (prompt, prompt_tokens) in enumerate(
                 zip(dummy_test_vectors.prompt_strings,
                     dummy_test_vectors.prompt_tokens))
@@ -594,8 +756,7 @@ def test_stop_string(include_stop_str_in_output: bool,
     # Confirmed tracked logprobs match what we expect
     _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
                        gen_cumulative_logprobs, dummy_test_vectors,
-                       request_id_list, num_sample_logprobs,
-                       num_prompt_logprobs)
+                       request_id_list, num_sample_logprobs, None)
 
     assert output_processor.get_num_unfinished_requests() == 0
     assert not output_processor.has_unfinished_requests()
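To exercise just the stop-token and stop-string tests locally, an invocation along these lines should work; the test-file path is an assumption, since it is not shown in these hunks:

import pytest

# Hypothetical path; point this at wherever test_stop_token and
# test_stop_string actually live in the checkout.
pytest.main([
    "tests/v1/engine/test_output_processor.py",
    "-k", "test_stop_token or test_stop_string",
    "-q",
])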
@@ -20,7 +20,7 @@ NUM_SAMPLE_LOGPROBS_UNDER_TEST = 5
 # Number of prompt logprobs to request when testing prompt logprobs
 NUM_PROMPT_LOGPROBS_UNDER_TEST = 7
 
-TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
+TOKENIZER_NAME = "meta-llama/Llama-3.2-1B"
 
 FULL_STRINGS = [
     "My name is Robert from Neural Magic and I love working on vLLM so much!",
@@ -330,13 +330,21 @@ class MockEngineCore:
         # each matrix has dimensions
         # (num prompt toks) x (num prompt logprobs+1)
         prompt_logprobs_raw: Optional[list[LogprobsTensors]] = None,
+        eos_token_id: Optional[int] = None,
+        stop_token_ids: Optional[list[int]] = None,
+        ignore_eos: bool = False,
     ) -> None:
+        self.num_requests = len(tokens_list)
         self.tokens_list = tokens_list
         self.current_idx = 0
         self.generated_logprobs_raw = generated_logprobs_raw
         self.do_logprobs = generated_logprobs_raw is not None
         self.prompt_logprobs_raw = prompt_logprobs_raw
         self.do_prompt_logprobs = prompt_logprobs_raw is not None
+        self.request_finished = [False for _ in range(self.num_requests)]
+        self.eos_token_id = eos_token_id
+        self.stop_token_ids = stop_token_ids
+        self.ignore_eos = ignore_eos
 
     def get_outputs(self) -> list[EngineCoreOutput]:
         do_logprobs = self.do_logprobs
@@ -345,7 +353,7 @@ class MockEngineCore:
 
         outputs = []
         for req_idx, token_ids in enumerate(self.tokens_list):
-            if len(token_ids) > token_idx:
+            if not self.request_finished[req_idx]:
                 if do_logprobs:
                     assert self.generated_logprobs_raw is not None
                     (logprobs_token_ids_, logprobs_, sampled_token_ranks_) = (
@@ -365,14 +373,23 @@ class MockEngineCore:
                         prompt_logprobs = None
                 else:
                     prompt_logprobs = None
+                new_token_id = token_ids[token_idx]
                 output = EngineCoreOutput(
                     request_id=f"request-{req_idx}",
-                    new_token_ids=[token_ids[token_idx]],
+                    new_token_ids=[new_token_id],
                     new_logprobs=logprobs,
                     new_prompt_logprobs_tensors=prompt_logprobs,
                 )
                 if token_idx == len(token_ids) - 1:
+                    output.finish_reason = FinishReason.LENGTH
+                    self.request_finished[req_idx] = True
+                if not self.ignore_eos and new_token_id == self.eos_token_id:
                     output.finish_reason = FinishReason.STOP
+                    self.request_finished[req_idx] = True
+                if new_token_id in (self.stop_token_ids or ()):
+                    output.finish_reason = FinishReason.STOP
+                    output.stop_reason = new_token_id
+                    self.request_finished[req_idx] = True
                 outputs.append(output)
 
         self.current_idx += 1
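In the updated mock, the three checks run in order (length, then EOS, then stop_token_ids), so a later match overrides the finish reason set by an earlier one for the same token. A rough standalone sketch of that precedence, using plain strings in place of the FinishReason enum and a hypothetical function name:

from typing import Optional, Tuple

def classify_token(new_token_id: int, is_last: bool,
                   eos_token_id: Optional[int],
                   stop_token_ids: Optional[list],
                   ignore_eos: bool) -> Tuple[Optional[str], Optional[int]]:
    """Return the (finish_reason, stop_reason) the mock attaches to one token."""
    finish_reason = None
    stop_reason = None
    if is_last:
        # Last scheduled token: length-based finish.
        finish_reason = "length"
    if not ignore_eos and new_token_id == eos_token_id:
        # EOS overrides a length finish unless EOS handling is disabled.
        finish_reason = "stop"
    if new_token_id in (stop_token_ids or ()):
        # A stop token also reports which token id triggered the stop.
        finish_reason = "stop"
        stop_reason = new_token_id
    return finish_reason, stop_reason

# With ignore_eos=True, an EOS on the final step still finishes with "length".
assert classify_token(2, True, eos_token_id=2, stop_token_ids=None,
                      ignore_eos=True) == ("length", None)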
@@ -88,7 +88,8 @@ class IncrementalDetokenizer:
             stop_buffer_length=stop_buffer_length,
         )
 
-    def update(self, new_token_ids: list[int]) -> Optional[str]:
+    def update(self, new_token_ids: list[int],
+               stop_terminated: bool) -> Optional[str]:
         """
         Update RequestState for the request_id by:
         1) Detokenize the new token ids incrementally.
@@ -96,11 +97,22 @@ class IncrementalDetokenizer:
 
         Return matched stop string or None.
         """
+        if not new_token_ids:
+            # Skip detokenization if no new token ids
+            return None
         if self.tokenizer is None:
+            # Skip detokenization if no tokenizer
             self.token_ids.extend(new_token_ids)
             return None
 
+        if stop_terminated and not self.include_stop_str_in_output:
+            # If stop-terminated, exclude last token from detokenization
+            # based on include_stop_str_in_output parameter.
+            skipped_stop_token_id = new_token_ids[-1]
+            new_token_ids = new_token_ids[:-1]
+        else:
+            skipped_stop_token_id = None
+
         # 1) Detokenize the new token ids incrementally.
         # TODO(woosuk): This method becomes very inefficient when the number of
         # new_token_ids is more than 1. We need to optimize this.
@@ -127,7 +139,14 @@ class IncrementalDetokenizer:
 
         self.output_text += decoded_text
 
-        # 2) Evaluate stop criteria.
+        if stop_terminated:
+            if skipped_stop_token_id is not None:
+                # Cleanup after skipping detokenization
+                self.token_ids.append(skipped_stop_token_id)
+            # Stop token triggered; skip stop string check
+            return None
+
+        # 2) Evaluate stop strings.
         stop_string = None
         if self.stop:
             stop = StopChecker.check_stop_strings(
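The effect of the new stop_terminated flag on what gets decoded can be summarized outside the class. A simplified standalone sketch of the skip-last-token behavior, with a hypothetical helper name; the real method also handles the missing-tokenizer case and incremental decoding:

from typing import List, Optional, Tuple

def tokens_to_detokenize(new_token_ids: List[int], stop_terminated: bool,
                         include_stop_str_in_output: bool
                         ) -> Tuple[List[int], Optional[int]]:
    """Return (ids to decode now, stop token id kept but not decoded)."""
    if stop_terminated and not include_stop_str_in_output:
        # The request ended on a stop/EOS token that must not appear in the
        # output text: keep it in the token history but skip decoding it.
        return new_token_ids[:-1], new_token_ids[-1]
    return new_token_ids, None

# The final stop token is withheld from decoding only when it must stay
# out of the output text.
assert tokens_to_detokenize([1, 2, 99], True, False) == ([1, 2], 99)
assert tokens_to_detokenize([1, 2, 99], True, True) == ([1, 2, 99], None)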
@@ -299,9 +299,9 @@ class OutputProcessor:
                 # in the EngineCore.
                 req_state.is_prefilling = not new_token_ids
 
-                # 2) Detokenize the token ids into text and check for stop
-                # strings.
-                stop_string = req_state.detokenizer.update(new_token_ids)
+                # 2) Detokenize the token ids into text and perform stop checks.
+                stop_string = req_state.detokenizer.update(
+                    new_token_ids, finish_reason == FinishReason.STOP)
                 if stop_string and finish_reason != FinishReason.STOP:
                     finish_reason = FinishReason.STOP
                     stop_reason = stop_string
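On the caller side, the output processor now tells the detokenizer whether the engine core already stop-terminated the request, and a stop-string match can still promote the finish reason afterwards. A condensed standalone sketch of that control flow, using a stub detokenizer rather than the real classes:

from enum import Enum
from typing import List, Optional, Tuple

class FinishReason(Enum):
    STOP = "stop"
    LENGTH = "length"

class StubDetokenizer:
    """Stand-in for the incremental detokenizer; reports a stop-string hit."""

    def __init__(self, stop_string_hit: Optional[str] = None):
        self.stop_string_hit = stop_string_hit

    def update(self, new_token_ids: List[int],
               stop_terminated: bool) -> Optional[str]:
        # The real detokenizer skips stop-string checks once the engine core
        # has already stop-terminated the request.
        return None if stop_terminated else self.stop_string_hit

def resolve_finish(detok: StubDetokenizer, new_token_ids: List[int],
                   finish_reason: Optional[FinishReason]
                   ) -> Tuple[Optional[FinishReason], Optional[str]]:
    # The caller passes whether a stop/EOS token already ended the request.
    stop_string = detok.update(new_token_ids,
                               finish_reason == FinishReason.STOP)
    if stop_string and finish_reason != FinishReason.STOP:
        # A matched stop string promotes the finish reason to STOP.
        finish_reason = FinishReason.STOP
    return finish_reason, stop_string

# A stop-string hit on a request that had not otherwise finished ends it.
assert resolve_finish(StubDetokenizer("DONE"), [1, 2], None) == (
    FinishReason.STOP, "DONE")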