[Bugfix] Remove the last EOS token unless explicitly specified (#5077)
This commit is contained in:
parent
5ae5ed1e60
commit
dfba529b40
86
tests/engine/output_processor/test_stop_checker.py
Normal file
86
tests/engine/output_processor/test_stop_checker.py
Normal file
@ -0,0 +1,86 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import Logprob, Sequence, SequenceStatus
|
||||
|
||||
|
||||
def sequence_with_eos(text: str, eos_token: str,
|
||||
eos_token_id: int) -> Sequence:
|
||||
"""
|
||||
Create a Sequence that ends with an EOS token.
|
||||
"""
|
||||
seq = Sequence(
|
||||
seq_id=0,
|
||||
prompt="",
|
||||
prompt_token_ids=[],
|
||||
block_size=16,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
seq.output_text = text + eos_token
|
||||
|
||||
offset = eos_token_id + 1
|
||||
for i in range(offset, len(text) + offset):
|
||||
seq.append_token_id(token_id=i, logprobs={i: Logprob(0.0)})
|
||||
seq.append_token_id(token_id=eos_token_id,
|
||||
logprobs={eos_token_id: Logprob(0.0)})
|
||||
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
return seq
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [
|
||||
("This text ends with EOS token", "</s>", 2),
|
||||
])
|
||||
@pytest.mark.parametrize("ignore_eos", [True, False, None])
|
||||
@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
|
||||
ignore_eos: bool, include_stop_str_in_output: bool):
|
||||
"""
|
||||
Test the behavior of the StopChecker's maybe_stop_sequence method
|
||||
when an EOS token is encountered.
|
||||
|
||||
This test covers:
|
||||
- When the EOS token should stop the sequence and be removed from the output
|
||||
- When the EOS token should stop the sequence and be included in the output
|
||||
- When the EOS token should be ignored, and the sequence continues
|
||||
"""
|
||||
|
||||
tokenizer = MagicMock(spec=PreTrainedTokenizer)
|
||||
get_tokenizer_for_seq = MagicMock(return_value=tokenizer)
|
||||
stop_checker = StopChecker(max_model_len=1024,
|
||||
get_tokenizer_for_seq=get_tokenizer_for_seq)
|
||||
|
||||
seq = sequence_with_eos(
|
||||
text=text_wo_eos,
|
||||
eos_token=eos_token,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
new_char_count = len(eos_token)
|
||||
|
||||
# Note that `stop` and `stop_token_ids` are not specified
|
||||
sampling_params = SamplingParams(
|
||||
min_tokens=1,
|
||||
ignore_eos=ignore_eos,
|
||||
include_stop_str_in_output=include_stop_str_in_output)
|
||||
|
||||
stop_checker.maybe_stop_sequence(
|
||||
seq=seq,
|
||||
new_char_count=new_char_count,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
if ignore_eos:
|
||||
assert seq.status == SequenceStatus.RUNNING
|
||||
assert seq.output_text == text_wo_eos + eos_token
|
||||
elif include_stop_str_in_output:
|
||||
assert seq.status == SequenceStatus.FINISHED_STOPPED
|
||||
assert seq.output_text == text_wo_eos + eos_token
|
||||
else:
|
||||
assert seq.status == SequenceStatus.FINISHED_STOPPED
|
||||
assert seq.output_text == text_wo_eos
|
@ -48,6 +48,11 @@ class StopChecker:
|
||||
# Check if the sequence has generated the EOS token.
|
||||
if ((not sampling_params.ignore_eos)
|
||||
and seq.get_last_token_id() == seq.eos_token_id):
|
||||
# Remove the last EOS token unless explicitly specified
|
||||
# This prevents unintended exposure of the EOS token
|
||||
if new_char_count and (
|
||||
not sampling_params.include_stop_str_in_output):
|
||||
seq.output_text = seq.output_text[:-new_char_count]
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user