[Bugfix] Remove the last EOS token unless explicitly specified (#5077)
This commit is contained in:
parent
5ae5ed1e60
commit
dfba529b40
86
tests/engine/output_processor/test_stop_checker.py
Normal file
86
tests/engine/output_processor/test_stop_checker.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.sequence import Logprob, Sequence, SequenceStatus
|
||||||
|
|
||||||
|
|
||||||
|
def sequence_with_eos(text: str, eos_token: str,
|
||||||
|
eos_token_id: int) -> Sequence:
|
||||||
|
"""
|
||||||
|
Create a Sequence that ends with an EOS token.
|
||||||
|
"""
|
||||||
|
seq = Sequence(
|
||||||
|
seq_id=0,
|
||||||
|
prompt="",
|
||||||
|
prompt_token_ids=[],
|
||||||
|
block_size=16,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
)
|
||||||
|
seq.output_text = text + eos_token
|
||||||
|
|
||||||
|
offset = eos_token_id + 1
|
||||||
|
for i in range(offset, len(text) + offset):
|
||||||
|
seq.append_token_id(token_id=i, logprobs={i: Logprob(0.0)})
|
||||||
|
seq.append_token_id(token_id=eos_token_id,
|
||||||
|
logprobs={eos_token_id: Logprob(0.0)})
|
||||||
|
|
||||||
|
seq.status = SequenceStatus.RUNNING
|
||||||
|
|
||||||
|
return seq
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [
|
||||||
|
("This text ends with EOS token", "</s>", 2),
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("ignore_eos", [True, False, None])
|
||||||
|
@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None])
|
||||||
|
@pytest.mark.skip_global_cleanup
|
||||||
|
def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
|
||||||
|
ignore_eos: bool, include_stop_str_in_output: bool):
|
||||||
|
"""
|
||||||
|
Test the behavior of the StopChecker's maybe_stop_sequence method
|
||||||
|
when an EOS token is encountered.
|
||||||
|
|
||||||
|
This test covers:
|
||||||
|
- When the EOS token should stop the sequence and be removed from the output
|
||||||
|
- When the EOS token should stop the sequence and be included in the output
|
||||||
|
- When the EOS token should be ignored, and the sequence continues
|
||||||
|
"""
|
||||||
|
|
||||||
|
tokenizer = MagicMock(spec=PreTrainedTokenizer)
|
||||||
|
get_tokenizer_for_seq = MagicMock(return_value=tokenizer)
|
||||||
|
stop_checker = StopChecker(max_model_len=1024,
|
||||||
|
get_tokenizer_for_seq=get_tokenizer_for_seq)
|
||||||
|
|
||||||
|
seq = sequence_with_eos(
|
||||||
|
text=text_wo_eos,
|
||||||
|
eos_token=eos_token,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
)
|
||||||
|
new_char_count = len(eos_token)
|
||||||
|
|
||||||
|
# Note that `stop` and `stop_token_ids` are not specified
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
min_tokens=1,
|
||||||
|
ignore_eos=ignore_eos,
|
||||||
|
include_stop_str_in_output=include_stop_str_in_output)
|
||||||
|
|
||||||
|
stop_checker.maybe_stop_sequence(
|
||||||
|
seq=seq,
|
||||||
|
new_char_count=new_char_count,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
if ignore_eos:
|
||||||
|
assert seq.status == SequenceStatus.RUNNING
|
||||||
|
assert seq.output_text == text_wo_eos + eos_token
|
||||||
|
elif include_stop_str_in_output:
|
||||||
|
assert seq.status == SequenceStatus.FINISHED_STOPPED
|
||||||
|
assert seq.output_text == text_wo_eos + eos_token
|
||||||
|
else:
|
||||||
|
assert seq.status == SequenceStatus.FINISHED_STOPPED
|
||||||
|
assert seq.output_text == text_wo_eos
|
@ -48,6 +48,11 @@ class StopChecker:
|
|||||||
# Check if the sequence has generated the EOS token.
|
# Check if the sequence has generated the EOS token.
|
||||||
if ((not sampling_params.ignore_eos)
|
if ((not sampling_params.ignore_eos)
|
||||||
and seq.get_last_token_id() == seq.eos_token_id):
|
and seq.get_last_token_id() == seq.eos_token_id):
|
||||||
|
# Remove the last EOS token unless explicitly specified
|
||||||
|
# This prevents unintended exposure of the EOS token
|
||||||
|
if new_char_count and (
|
||||||
|
not sampling_params.include_stop_str_in_output):
|
||||||
|
seq.output_text = seq.output_text[:-new_char_count]
|
||||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
return
|
return
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user