[Core] [Frontend] Make detokenization optional (#3749)
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
parent 498eb5cfa3
commit aabe8f40f2
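
This change adds a `detokenize` flag to `SamplingParams` (default `True`). When it is set to `False`, the engine skips detokenization of prompt logprobs and of generated sequences, so completions come back with `token_ids` but an empty `text`, and stop strings are rejected because they are matched against detokenized output. For now the flag is only exposed at the engine level, not through the OpenAI API server.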
tests/engine/test_detokenization.py (new file, +32 lines)

@@ -0,0 +1,32 @@
+import pytest
+
+from vllm.entrypoints.llm import LLM
+from vllm.sampling_params import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_computed_prefix_blocks(model: str):
+    # This test checks if the engine generates completions both with and
+    # without optional detokenization, that detokenization includes text
+    # and no-detokenization doesn't, and that both completions have the same
+    # token_ids.
+    prompt = (
+        "You are a helpful assistant. How do I build a car from cardboard and "
+        "paper clips? Is there an easy to follow video tutorial available "
+        "online for free?")
+
+    llm = LLM(model=model)
+    sampling_params = SamplingParams(max_tokens=10,
+                                     temperature=0.0,
+                                     detokenize=False)
+
+    outputs_no_detokenization = llm.generate(prompt,
+                                             sampling_params)[0].outputs[0]
+    sampling_params.detokenize = True
+    outputs_with_detokenization = llm.generate(prompt,
+                                               sampling_params)[0].outputs[0]
+
+    assert outputs_no_detokenization.text == ''
+    assert outputs_with_detokenization.text != ''
+    assert outputs_no_detokenization.token_ids == \
+        outputs_with_detokenization.token_ids
vllm/engine/llm_engine.py

@@ -432,7 +432,7 @@ class LLMEngine:

         # Process prompt logprobs
         prompt_logprobs = outputs.prompt_logprobs
-        if prompt_logprobs is not None:
+        if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
             self.detokenizer.decode_prompt_logprobs_inplace(
                 seq_group, prompt_logprobs)
             seq_group.prompt_logprobs = prompt_logprobs
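(With `detokenize=False`, the prompt logprobs are still attached to the sequence group; only the in-place decoding of their text is skipped.)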
@@ -478,8 +478,9 @@ class LLMEngine:
                 child_seqs.append((parent, parent))

         for seq, _ in child_seqs:
-            self.detokenizer.decode_sequence_inplace(seq,
-                                                     seq_group.sampling_params)
+            if seq_group.sampling_params.detokenize:
+                self.detokenizer.decode_sequence_inplace(
+                    seq, seq_group.sampling_params)
             self._check_stop(seq, seq_group.sampling_params)

         # Non-beam search case
@@ -791,12 +792,13 @@ class LLMEngine:
         if seq.get_output_len() < sampling_params.min_tokens:
             return

-        for stop_str in sampling_params.stop:
-            if seq.output_text.endswith(stop_str):
-                self._finalize_sequence(seq, sampling_params, stop_str)
-                seq.status = SequenceStatus.FINISHED_STOPPED
-                seq.stop_reason = stop_str
-                return
+        if sampling_params.detokenize:
+            for stop_str in sampling_params.stop:
+                if seq.output_text.endswith(stop_str):
+                    self._finalize_sequence(seq, sampling_params, stop_str)
+                    seq.status = SequenceStatus.FINISHED_STOPPED
+                    seq.stop_reason = stop_str
+                    return

         last_token_id = seq.get_last_token_id()
         if last_token_id in sampling_params.stop_token_ids:
             stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
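(Stop strings are matched against `seq.output_text`, which is only filled in by the detokenizer, so the whole stop-string loop is now gated on `sampling_params.detokenize`; `stop_token_ids` are checked either way.)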
vllm/sampling_params.py

@@ -88,6 +88,7 @@ class SamplingParams:
             log probability of the sampled token, so there may be up to
             `logprobs+1` elements in the response.
         prompt_logprobs: Number of log probabilities to return per prompt token.
+        detokenize: Whether to detokenize the output. Defaults to True.
         skip_special_tokens: Whether to skip special tokens in the output.
         spaces_between_special_tokens: Whether to add spaces between special
             tokens in the output. Defaults to True.
@@ -118,6 +119,7 @@ class SamplingParams:
         min_tokens: int = 0,
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
+        detokenize: bool = True,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessor]] = None,
@@ -150,6 +152,10 @@ class SamplingParams:
         self.min_tokens = min_tokens
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
+        # NOTE: This parameter is only exposed at the engine level for now.
+        # It is not exposed in the OpenAI API server, as the OpenAI API does
+        # not support returning only a list of token IDs.
+        self.detokenize = detokenize
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
@@ -210,6 +216,10 @@ class SamplingParams:
         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
+        if self.stop and not self.detokenize:
+            raise ValueError(
+                "stop strings are only supported when detokenize is True. "
+                "Set detokenize=True to use stop.")

     def _verify_beam_search(self) -> None:
         if self.best_of == 1:
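
A minimal usage sketch (not part of the commit) of how the new flag behaves through the engine-level LLM entrypoint exercised by the test above; the model name, prompt, and generation settings are placeholders:

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model

# Skip detokenization: the completion carries token_ids but no text.
params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False)
completion = llm.generate("Hello, my name is", params)[0].outputs[0]
print(completion.token_ids)   # populated as usual
print(repr(completion.text))  # '' because nothing was detokenized

# Stop strings need detokenized text to match against, so combining them
# with detokenize=False is rejected when the params are constructed.
try:
    SamplingParams(stop=["\n"], detokenize=False)
except ValueError as err:
    print(err)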