[Core] [Frontend] Make detokenization optional (#3749)

Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Matthias Gerstgrasser 2024-04-03 21:52:18 -07:00 committed by GitHub
parent 498eb5cfa3
commit aabe8f40f2
3 changed files with 53 additions and 9 deletions
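
In short: setting the new SamplingParams flag skips the decode step, so outputs carry token IDs but no text. A minimal usage sketch based on the test added below (model and prompt are illustrative):

from vllm import LLM, SamplingParams

# With detokenize=False the engine skips text decoding entirely.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False)
output = llm.generate("Hello, my name is", params)[0].outputs[0]
print(output.token_ids)   # token IDs are populated as usual
print(repr(output.text))  # expected to be '' since decoding was skipped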


@@ -0,0 +1,32 @@
+import pytest
+
+from vllm.entrypoints.llm import LLM
+from vllm.sampling_params import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_detokenize(model: str):
+    # Check that the engine generates completions both with and without
+    # detokenization, that the detokenized output contains text while the
+    # non-detokenized output does not, and that both completions return
+    # the same token_ids.
+    prompt = (
+        "You are a helpful assistant. How do I build a car from cardboard and "
+        "paper clips? Is there an easy to follow video tutorial available "
+        "online for free?")
+
+    llm = LLM(model=model)
+    sampling_params = SamplingParams(max_tokens=10,
+                                     temperature=0.0,
+                                     detokenize=False)
+
+    outputs_no_detokenization = llm.generate(prompt,
+                                             sampling_params)[0].outputs[0]
+    sampling_params.detokenize = True
+    outputs_with_detokenization = llm.generate(prompt,
+                                               sampling_params)[0].outputs[0]
+
+    assert outputs_no_detokenization.text == ''
+    assert outputs_with_detokenization.text != ''
+    assert outputs_no_detokenization.token_ids == \
+        outputs_with_detokenization.token_ids

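Callers that disable detokenization can still recover text downstream. A hedged sketch decoding the collected IDs directly with a Hugging Face tokenizer (the tokenizer name mirrors the test model; `outputs_no_detokenization` stands in for the result gathered in the test above):

from transformers import AutoTokenizer

# Decode token IDs outside the engine, e.g. in a postprocessing step.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
token_ids = outputs_no_detokenization.token_ids  # from the test above
text = tokenizer.decode(token_ids, skip_special_tokens=True)
print(text)
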
vllm/engine/llm_engine.py

@@ -432,7 +432,7 @@ class LLMEngine:
         # Process prompt logprobs
         prompt_logprobs = outputs.prompt_logprobs
-        if prompt_logprobs is not None:
+        if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
             self.detokenizer.decode_prompt_logprobs_inplace(
                 seq_group, prompt_logprobs)
             seq_group.prompt_logprobs = prompt_logprobs
@@ -478,8 +478,9 @@ class LLMEngine:
             child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
-            self.detokenizer.decode_sequence_inplace(seq,
-                                                     seq_group.sampling_params)
+            if seq_group.sampling_params.detokenize:
+                self.detokenizer.decode_sequence_inplace(
+                    seq, seq_group.sampling_params)
             self._check_stop(seq, seq_group.sampling_params)
 
         # Non-beam search case
@@ -791,6 +792,7 @@ class LLMEngine:
         if seq.get_output_len() < sampling_params.min_tokens:
             return
 
+        if sampling_params.detokenize:
             for stop_str in sampling_params.stop:
                 if seq.output_text.endswith(stop_str):
                     self._finalize_sequence(seq, sampling_params, stop_str)

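The first hunk also gates prompt-logprob decoding: with detokenize=False the engine skips decode_prompt_logprobs_inplace, so (assuming the returned Logprob entries' decoded_token field is only filled in by that call) logprob values come back without decoded text. A sketch under that assumption:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=1, prompt_logprobs=1, detokenize=False)
out = llm.generate("Hello world", params)[0]
# Entries map token_id -> Logprob; decoded_token is expected to stay
# None here because decode_prompt_logprobs_inplace was skipped.
print(out.prompt_logprobs)
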
vllm/sampling_params.py

@@ -88,6 +88,7 @@ class SamplingParams:
             log probability of the sampled token, so there may be up to
             `logprobs+1` elements in the response.
         prompt_logprobs: Number of log probabilities to return per prompt token.
+        detokenize: Whether to detokenize the output. Defaults to True.
         skip_special_tokens: Whether to skip special tokens in the output.
         spaces_between_special_tokens: Whether to add spaces between special
             tokens in the output. Defaults to True.
@@ -118,6 +119,7 @@
         min_tokens: int = 0,
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
+        detokenize: bool = True,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessor]] = None,
@@ -150,6 +152,10 @@
         self.min_tokens = min_tokens
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
+        # NOTE: This parameter is only exposed at the engine level for now.
+        # It is not exposed in the OpenAI API server, as the OpenAI API does
+        # not support returning only a list of token IDs.
+        self.detokenize = detokenize
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
@@ -210,6 +216,10 @@
         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
+        if self.stop and not self.detokenize:
+            raise ValueError(
+                "stop strings are only supported when detokenize is True. "
+                "Set detokenize=True to use stop.")
 
     def _verify_beam_search(self) -> None:
         if self.best_of == 1:
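
The new guard makes the stop-string incompatibility explicit at construction time; a quick pytest-style check mirroring the error added above:

import pytest

from vllm.sampling_params import SamplingParams

# Stop strings are matched against detokenized text, so this must raise.
with pytest.raises(ValueError):
    SamplingParams(stop=["\n"], detokenize=False)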