[BugFix] Fix handling of stop strings and stop token ids (#3672)
This commit is contained in:
parent 1e96c3341a
commit e46a60aa4c
@@ -401,7 +401,7 @@ class VllmRunner:
         cleanup()


-@pytest.fixture
+@pytest.fixture(scope="session")
 def vllm_runner():
     return VllmRunner

@@ -3,7 +3,7 @@
 2. One of the provided stop tokens
 3. The EOS token

-Run `pytest tests/samplers/test_stop_reason.py`.
+Run `pytest tests/engine/test_stop_reason.py`.
 """

 import pytest
tests/engine/test_stop_strings.py (new file, 111 lines)
@@ -0,0 +1,111 @@
from typing import Any, List, Optional

import pytest

from vllm import CompletionOutput, LLMEngine, SamplingParams

MODEL = "meta-llama/llama-2-7b-hf"
MAX_TOKENS = 200


@pytest.fixture(scope="session")
def vllm_model(vllm_runner):
    return vllm_runner(MODEL)


@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
    _test_stopping(vllm_model.model.llm_engine,
                   stop=["."],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer organization",
                   expected_reason=".")

    _test_stopping(vllm_model.model.llm_engine,
                   stop=["."],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organization.",
                   expected_reason=".")


@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
    _test_stopping(
        vllm_model.model.llm_engine,
        stop=["group of peo", "short"],
        include_in_output=False,
        expected_output="VLLM is a 100% volunteer organization. We are a ",
        expected_reason="group of peo")

    _test_stopping(
        vllm_model.model.llm_engine,
        stop=["group of peo", "short"],
        include_in_output=True,
        expected_output=
        "VLLM is a 100% volunteer organization. We are a group of peo",
        expected_reason="group of peo")


@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
    _test_stopping(vllm_model.model.llm_engine,
                   stop=["gani"],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer or",
                   expected_reason="gani")

    _test_stopping(vllm_model.model.llm_engine,
                   stop=["gani"],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organi",
                   expected_reason="gani")


@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
    # token id 13013 => " organization"

    _test_stopping(vllm_model.model.llm_engine,
                   stop_token_ids=[13013],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer",
                   expected_reason=13013)

    _test_stopping(vllm_model.model.llm_engine,
                   stop_token_ids=[13013],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organization",
                   expected_reason=13013)


def _test_stopping(llm_engine: LLMEngine,
                   expected_output: str,
                   expected_reason: Any,
                   stop: Optional[List[str]] = None,
                   stop_token_ids: Optional[List[int]] = None,
                   include_in_output: bool = False) -> None:
    llm_engine.add_request(
        "id", "A story about vLLM:\n",
        SamplingParams(
            temperature=0.0,
            max_tokens=MAX_TOKENS,
            stop=stop,
            stop_token_ids=stop_token_ids,
            include_stop_str_in_output=include_in_output,
        ), None)

    output: Optional[CompletionOutput] = None
    output_text = ""
    stop_reason = None
    while llm_engine.has_unfinished_requests():
        (request_output, ) = llm_engine.step()
        (output, ) = request_output.outputs

        # Ensure we don't backtrack
        assert output.text.startswith(output_text)
        output_text = output.text
        stop_reason = output.stop_reason

    assert output is not None
    assert output_text == expected_output
    assert stop_reason == expected_reason
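Note (usage sketch, not part of the diff): the same stop handling can be exercised through the offline LLM API. The model below is the one used by these tests, and the parameter values mirror the cases above.

    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/llama-2-7b-hf")

    params = SamplingParams(
        temperature=0.0,
        max_tokens=200,
        stop=["."],                        # stop at the first period...
        stop_token_ids=[13013],            # ...or at this token id
        include_stop_str_in_output=False,  # trim the matched stop text
    )

    (result, ) = llm.generate(["A story about vLLM:\n"], params)
    completion = result.outputs[0]
    print(completion.text)         # output truncated at the stop match
    print(completion.stop_reason)  # the matched stop string or token id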
@@ -501,9 +501,11 @@ class LLMEngine:

         for seq, _ in child_seqs:
             if seq_group.sampling_params.detokenize:
-                self.detokenizer.decode_sequence_inplace(
+                new_char_count = self.detokenizer.decode_sequence_inplace(
                     seq, seq_group.sampling_params)
-            self._check_stop(seq, seq_group.sampling_params)
+            else:
+                new_char_count = 0
+            self._check_stop(seq, new_char_count, seq_group.sampling_params)

         # Non-beam search case
         if not seq_group.sampling_params.use_beam_search:
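Note (illustrative sketch, not vLLM source; process_step and check_stop are made-up names): the point of threading new_char_count through is that the stop checks only inspect text the detokenizer actually produced in this step; with detokenize=False a count of zero is passed and the character-level checks become no-ops.

    def process_step(seq, params, detokenizer, check_stop):
        # Decode first and remember how many characters this step added,
        # then let the stop check trim exactly that much if a stop is hit.
        if params.detokenize:
            new_char_count = detokenizer.decode_sequence_inplace(seq, params)
        else:
            new_char_count = 0
        check_stop(seq, new_char_count, params)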
@@ -798,9 +800,45 @@ class LLMEngine:
             time_e2e_requests=time_e2e_requests,
         )

-    def _check_stop(self, seq: Sequence,
+    def _check_stop(self, seq: Sequence, new_char_count: int,
                     sampling_params: SamplingParams) -> None:
-        """Stop the finished sequences."""
+        """Stop the finished sequences.
+
+        new_char_count is the number of chars added to the
+        sequence's output text for the newly generated token
+        """
+
+        # Check if the minimum number of tokens has been generated yet;
+        # skip the stop string/token checks if not
+        if seq.get_output_len() < sampling_params.min_tokens:
+            return
+
+        # Check if the sequence has generated the EOS token.
+        if ((not sampling_params.ignore_eos)
+                and seq.get_last_token_id() == seq.eos_token_id):
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            return
+
+        # Check if a stop token was encountered.
+        # This assumes a single token produced per step.
+        last_token_id = seq.get_last_token_id()
+        if last_token_id in sampling_params.stop_token_ids:
+            if new_char_count and (
+                    not sampling_params.include_stop_str_in_output):
+                # Remove last token
+                seq.output_text = seq.output_text[:-new_char_count]
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = last_token_id
+            return
+
+        # Check if any stop strings are matched.
+        stop_str = self._check_stop_strings(seq, new_char_count,
+                                            sampling_params)
+        if stop_str is not None:
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = stop_str
+            return
+
         # Check if the sequence has reached max_model_len.
         if seq.get_len() > self.scheduler_config.max_model_len:
             seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
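Note (worked example, not part of the diff; values taken from the new tests): in the stop-token branch above, token id 13013 decodes to " organization" (13 characters), so with include_stop_str_in_output=False only the text contributed by that final token is removed.

    output_text = "VLLM is a 100% volunteer organization"
    new_char_count = len(" organization")            # chars added by the final token

    include_stop_str_in_output = False
    if new_char_count and not include_stop_str_in_output:
        output_text = output_text[:-new_char_count]  # drop the stop token's text

    assert output_text == "VLLM is a 100% volunteer"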
@@ -811,43 +849,37 @@ class LLMEngine:
             seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
             return

-        # Check if the minimum number of tokens has been generated yet;
-        # skip the stop string/token checks if not
-        if seq.get_output_len() < sampling_params.min_tokens:
-            return
-
-        if sampling_params.detokenize:
-            for stop_str in sampling_params.stop:
-                if seq.output_text.endswith(stop_str):
-                    self._finalize_sequence(seq, sampling_params, stop_str)
-                    seq.status = SequenceStatus.FINISHED_STOPPED
-                    seq.stop_reason = stop_str
-                    return
-        last_token_id = seq.get_last_token_id()
-        if last_token_id in sampling_params.stop_token_ids:
-            stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
-                last_token_id)
-            self._finalize_sequence(seq, sampling_params, stop_str)
-            seq.status = SequenceStatus.FINISHED_STOPPED
-            seq.stop_reason = last_token_id
-            return
-
-        # Check if the sequence has generated the EOS token.
-        if ((not sampling_params.ignore_eos)
-                and seq.get_last_token_id() == seq.eos_token_id):
-            seq.status = SequenceStatus.FINISHED_STOPPED
-            return
-
-    def _finalize_sequence(self, seq: Sequence,
-                           sampling_params: SamplingParams,
-                           stop_string: str) -> None:
-        if sampling_params.include_stop_str_in_output:
-            return
-
-        if stop_string and seq.output_text.endswith(stop_string):
-            # Truncate the output text so that the stop string is
-            # not included in the output.
-            seq.output_text = seq.output_text[:-len(stop_string)]
+    @staticmethod
+    def _check_stop_strings(seq: Sequence, new_char_count: int,
+                            sampling_params: SamplingParams) -> Optional[str]:
+        """Check if any stop strings are matched and truncate sequence
+        output text accordingly.
+
+        Returns the stop string if matched or else None.
+        """
+        if not new_char_count:
+            return None
+
+        for stop_str in sampling_params.stop:
+            stop_string_len = len(stop_str)
+            # Avoid searching already-searched text.
+            stop_index = seq.output_text.find(
+                stop_str, -new_char_count - stop_string_len)
+            if stop_index == -1:
+                continue
+
+            if sampling_params.include_stop_str_in_output:
+                # Truncate to end of stop string.
+                stop_index += stop_string_len
+                if stop_index >= len(seq.output_text):
+                    # No truncation required.
+                    return stop_str
+
+            # Truncate the output text to either the beginning
+            # or end of the stop string.
+            seq.output_text = seq.output_text[:stop_index]
+            return stop_str
+        return None

     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_executor.add_lora(lora_request)
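Note (worked example, not part of the diff; the token split is assumed): the negative start passed to str.find restricts the search to the characters added this step plus a stop-string-sized margin, so earlier text is never rescanned. Using the multi-token stop string from the tests:

    output_text = "VLLM is a 100% volunteer organization. We are a group of people"
    stop_str = "group of peo"
    new_char_count = len(" people")   # suppose the final token decoded to " people"

    # Search only the new characters plus a len(stop_str)-sized margin;
    # a negative start indexes from the end of the string.
    stop_index = output_text.find(stop_str, -new_char_count - len(stop_str))
    assert stop_index != -1

    # With include_stop_str_in_output=False the text is cut at the start
    # of the matched stop string.
    assert output_text[:stop_index] == (
        "VLLM is a 100% volunteer organization. We are a ")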
@@ -112,8 +112,10 @@ class RequestOutput:
         # always has the logprobs of the sampled tokens even if the
         # logprobs are not requested.
         include_logprobs = seq_group.sampling_params.logprobs is not None
+        text_buffer_length = seq_group.sampling_params.output_text_buffer_length
         outputs = [
-            CompletionOutput(seqs.index(seq), seq.output_text,
+            CompletionOutput(seqs.index(seq),
+                             seq.get_output_text_to_return(text_buffer_length),
                              seq.get_output_token_ids(),
                              seq.get_cumulative_logprob(),
                              seq.output_logprobs if include_logprobs else None,
@@ -166,6 +166,13 @@ class SamplingParams:
         self.logits_processors = logits_processors
         self.include_stop_str_in_output = include_stop_str_in_output
         self.truncate_prompt_tokens = truncate_prompt_tokens
+        # Number of characters to hold back for stop string evaluation
+        # until sequence is finished.
+        if self.stop and not include_stop_str_in_output:
+            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
+        else:
+            self.output_text_buffer_length = 0
+
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
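Note (worked example, not part of the diff; stop strings taken from the tests): the held-back length is one less than the longest stop string, since a complete match would already have finished and truncated the sequence; at most a proper prefix of a stop string can still be pending.

    stop = ["group of peo", "short"]
    include_stop_str_in_output = False

    output_text_buffer_length = (max(len(s) for s in stop) - 1
                                 if stop and not include_stop_str_in_output
                                 else 0)
    assert output_text_buffer_length == 11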
@@ -226,6 +233,8 @@ class SamplingParams:
                 and self.truncate_prompt_tokens < 1):
             raise ValueError(f"truncate_prompt_tokens must be >= 1, "
                              f"got {self.truncate_prompt_tokens}")
+        if any(not stop_str for stop_str in self.stop):
+            raise ValueError("stop cannot contain an empty string.")
         if self.stop and not self.detokenize:
             raise ValueError(
                 "stop strings are only supported when detokenize is True. "
@@ -235,6 +235,12 @@ class Sequence:
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0

+    def get_output_text_to_return(self, buffer_length: int):
+        # We return the full output text if the sequence is finished.
+        truncate = buffer_length and not self.is_finished()
+        return self.output_text[:-buffer_length] if truncate else (
+            self.output_text)
+
     def hash_of_block(self, logical_idx: int) -> int:
         # TODO This can produce incorrect hash when block size > prompt size
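Note (illustrative, not part of the diff): while a request is still running, the last buffer_length characters are withheld from RequestOutput, because a stop-string match completed by a later token could still truncate them; the full text is released once the sequence finishes. With stop=["gani"] the buffer is len("gani") - 1 = 3:

    output_text = "VLLM is a 100% volunteer or"
    buffer_length = 3                  # len("gani") - 1
    is_finished = False

    truncate = buffer_length and not is_finished
    visible = output_text[:-buffer_length] if truncate else output_text
    assert visible == "VLLM is a 100% volunteer"   # " or" withheld for now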
@@ -87,12 +87,15 @@ class Detokenizer:
         prev_tokens.extend(next_iter_tokens)

     def decode_sequence_inplace(self, seq: Sequence,
-                                prms: SamplingParams) -> None:
+                                prms: SamplingParams) -> int:
         """Decodes the new token for a sequence. In-place operation.

         Args:
             seq: The sequence to decode.
             prms: The sampling parameters used to generate the sequence.
+
+        Returns:
+            The number of characters added to the output text.
         """
         all_input_ids = seq.get_token_ids()
         token_id_generated_this_iteration = all_input_ids[-1]
@@ -151,6 +154,8 @@ class Detokenizer:
         seq.read_offset = read_offset
         seq.output_text += new_decoded_token_text

+        return len(new_decoded_token_text)
+

 def _convert_tokens_to_string_with_added_encoders(
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
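Note (illustrative, not part of the diff): the return value is a character count rather than a token count because one decoding step can add several characters, or none at all while a multi-byte character is still incomplete, which is why _check_stop_strings returns early when the count is zero.

    before = "VLLM is a 100% volunteer"
    after = "VLLM is a 100% volunteer organization"   # one token, " organization"
    new_char_count = len(after) - len(before)          # 13 characters in one step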