[Core] [Frontend] Make detokenization optional (#3749)
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
parent 498eb5cfa3
commit aabe8f40f2
tests/engine/test_detokenization.py (new file, 32 lines added)
@@ -0,0 +1,32 @@
+import pytest
+
+from vllm.entrypoints.llm import LLM
+from vllm.sampling_params import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_computed_prefix_blocks(model: str):
+    # This test checks if the engine generates completions both with and
+    # without optional detokenization, that detokenization includes text
+    # and no-detokenization doesn't, and that both completions have the same
+    # token_ids.
+    prompt = (
+        "You are a helpful assistant. How do I build a car from cardboard and "
+        "paper clips? Is there an easy to follow video tutorial available "
+        "online for free?")
+
+    llm = LLM(model=model)
+    sampling_params = SamplingParams(max_tokens=10,
+                                     temperature=0.0,
+                                     detokenize=False)
+
+    outputs_no_detokenization = llm.generate(prompt,
+                                             sampling_params)[0].outputs[0]
+    sampling_params.detokenize = True
+    outputs_with_detokenization = llm.generate(prompt,
+                                               sampling_params)[0].outputs[0]
+
+    assert outputs_no_detokenization.text == ''
+    assert outputs_with_detokenization.text != ''
+    assert outputs_no_detokenization.token_ids == \
+        outputs_with_detokenization.token_ids
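For callers that disable engine-side detokenization but still want text later, the returned token IDs can be decoded separately. A minimal sketch, not part of this commit, assuming the matching Hugging Face tokenizer is acceptable for offline decoding:

from transformers import AutoTokenizer

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams

# Generate token IDs only; the engine skips detokenization entirely.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False)
output = llm.generate("Hello, my name is", params)[0].outputs[0]
assert output.text == ''  # no text is produced when detokenize=False

# Decode later (or in a separate process) with the matching HF tokenizer.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
print(tokenizer.decode(output.token_ids, skip_special_tokens=True))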
vllm/engine/llm_engine.py
@@ -432,7 +432,7 @@ class LLMEngine:
 
         # Process prompt logprobs
         prompt_logprobs = outputs.prompt_logprobs
-        if prompt_logprobs is not None:
+        if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
             self.detokenizer.decode_prompt_logprobs_inplace(
                 seq_group, prompt_logprobs)
             seq_group.prompt_logprobs = prompt_logprobs
@@ -478,8 +478,9 @@ class LLMEngine:
             child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
-            self.detokenizer.decode_sequence_inplace(seq,
-                                                     seq_group.sampling_params)
+            if seq_group.sampling_params.detokenize:
+                self.detokenizer.decode_sequence_inplace(
+                    seq, seq_group.sampling_params)
             self._check_stop(seq, seq_group.sampling_params)
 
         # Non-beam search case
@@ -791,12 +792,13 @@ class LLMEngine:
         if seq.get_output_len() < sampling_params.min_tokens:
             return
 
-        for stop_str in sampling_params.stop:
-            if seq.output_text.endswith(stop_str):
-                self._finalize_sequence(seq, sampling_params, stop_str)
-                seq.status = SequenceStatus.FINISHED_STOPPED
-                seq.stop_reason = stop_str
-                return
+        if sampling_params.detokenize:
+            for stop_str in sampling_params.stop:
+                if seq.output_text.endswith(stop_str):
+                    self._finalize_sequence(seq, sampling_params, stop_str)
+                    seq.status = SequenceStatus.FINISHED_STOPPED
+                    seq.stop_reason = stop_str
+                    return
         last_token_id = seq.get_last_token_id()
         if last_token_id in sampling_params.stop_token_ids:
             stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
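Stop strings are matched against the detokenized output_text, so the string check above is skipped entirely when detokenize is False; stopping on token IDs still applies. A hedged sketch of relying on stop_token_ids instead, with the ID lookup via transformers used only for illustration and not part of this commit:

from transformers import AutoTokenizer

from vllm.sampling_params import SamplingParams

# With detokenization disabled, stop on token IDs rather than stop strings,
# e.g. by passing an explicit token ID such as the model's EOS ID.
eos_id = AutoTokenizer.from_pretrained("facebook/opt-125m").eos_token_id
params = SamplingParams(max_tokens=64,
                        detokenize=False,
                        stop_token_ids=[eos_id])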
vllm/sampling_params.py
@@ -88,6 +88,7 @@ class SamplingParams:
             log probability of the sampled token, so there may be up to
             `logprobs+1` elements in the response.
         prompt_logprobs: Number of log probabilities to return per prompt token.
+        detokenize: Whether to detokenize the output. Defaults to True.
         skip_special_tokens: Whether to skip special tokens in the output.
         spaces_between_special_tokens: Whether to add spaces between special
             tokens in the output. Defaults to True.
@@ -118,6 +119,7 @@ class SamplingParams:
         min_tokens: int = 0,
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
+        detokenize: bool = True,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessor]] = None,
@@ -150,6 +152,10 @@ class SamplingParams:
         self.min_tokens = min_tokens
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
+        # NOTE: This parameter is only exposed at the engine level for now.
+        # It is not exposed in the OpenAI API server, as the OpenAI API does
+        # not support returning only a list of token IDs.
+        self.detokenize = detokenize
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
@@ -210,6 +216,10 @@ class SamplingParams:
         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
+        if self.stop and not self.detokenize:
+            raise ValueError(
+                "stop strings are only supported when detokenize is True. "
+                "Set detokenize=True to use stop.")
 
     def _verify_beam_search(self) -> None:
         if self.best_of == 1:
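The new check makes the unsupported combination fail fast when the parameters are constructed. A small illustrative test, not part of this commit, of the behavior the check enforces:

import pytest

from vllm.sampling_params import SamplingParams


def test_stop_requires_detokenize():
    # Stop strings need detokenized text to match against, so combining
    # stop= with detokenize=False is rejected at construction time.
    with pytest.raises(ValueError):
        SamplingParams(stop=["\n"], detokenize=False)

    # Stop strings with detokenization enabled (the default) are accepted.
    SamplingParams(stop=["\n"])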