diff --git a/tests/engine/test_detokenization.py b/tests/engine/test_detokenization.py
new file mode 100644
index 00000000..f77f6d07
--- /dev/null
+++ b/tests/engine/test_detokenization.py
@@ -0,0 +1,32 @@
+import pytest
+
+from vllm.entrypoints.llm import LLM
+from vllm.sampling_params import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_detokenize(model: str):
+    # Check that the engine can generate completions both with and without
+    # detokenization, that the detokenized output contains text while the
+    # non-detokenized output does not, and that both produce the same
+    # token_ids.
+    prompt = (
+        "You are a helpful assistant. How do I build a car from cardboard and "
+        "paper clips? Is there an easy to follow video tutorial available "
+        "online for free?")
+
+    llm = LLM(model=model)
+    sampling_params = SamplingParams(max_tokens=10,
+                                     temperature=0.0,
+                                     detokenize=False)
+
+    outputs_no_detokenization = llm.generate(prompt,
+                                             sampling_params)[0].outputs[0]
+    sampling_params.detokenize = True
+    outputs_with_detokenization = llm.generate(prompt,
+                                               sampling_params)[0].outputs[0]
+
+    assert outputs_no_detokenization.text == ''
+    assert outputs_with_detokenization.text != ''
+    assert outputs_no_detokenization.token_ids == \
+        outputs_with_detokenization.token_ids
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 5c343921..c22585a3 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -432,7 +432,7 @@ class LLMEngine:
 
         # Process prompt logprobs
         prompt_logprobs = outputs.prompt_logprobs
-        if prompt_logprobs is not None:
+        if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
             self.detokenizer.decode_prompt_logprobs_inplace(
                 seq_group, prompt_logprobs)
             seq_group.prompt_logprobs = prompt_logprobs
@@ -478,8 +478,9 @@ class LLMEngine:
                 child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
-            self.detokenizer.decode_sequence_inplace(seq,
-                                                     seq_group.sampling_params)
+            if seq_group.sampling_params.detokenize:
+                self.detokenizer.decode_sequence_inplace(
+                    seq, seq_group.sampling_params)
             self._check_stop(seq, seq_group.sampling_params)
 
         # Non-beam search case
@@ -791,12 +792,13 @@ class LLMEngine:
         if seq.get_output_len() < sampling_params.min_tokens:
             return
 
-        for stop_str in sampling_params.stop:
-            if seq.output_text.endswith(stop_str):
-                self._finalize_sequence(seq, sampling_params, stop_str)
-                seq.status = SequenceStatus.FINISHED_STOPPED
-                seq.stop_reason = stop_str
-                return
+        if sampling_params.detokenize:
+            for stop_str in sampling_params.stop:
+                if seq.output_text.endswith(stop_str):
+                    self._finalize_sequence(seq, sampling_params, stop_str)
+                    seq.status = SequenceStatus.FINISHED_STOPPED
+                    seq.stop_reason = stop_str
+                    return
         last_token_id = seq.get_last_token_id()
         if last_token_id in sampling_params.stop_token_ids:
             stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 6f81ee31..bbba02a8 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -88,6 +88,7 @@ class SamplingParams:
             log probability of the sampled token, so there may be up to
             `logprobs+1` elements in the response.
         prompt_logprobs: Number of log probabilities to return per prompt token.
+        detokenize: Whether to detokenize the output. Defaults to True.
         skip_special_tokens: Whether to skip special tokens in the output.
         spaces_between_special_tokens: Whether to add spaces between special
            tokens in the output. Defaults to True.
@@ -118,6 +119,7 @@ class SamplingParams:
         min_tokens: int = 0,
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
+        detokenize: bool = True,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessor]] = None,
@@ -150,6 +152,10 @@ class SamplingParams:
         self.min_tokens = min_tokens
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
+        # NOTE: This parameter is only exposed at the engine level for now.
+        # It is not exposed in the OpenAI API server, as the OpenAI API does
+        # not support returning only a list of token IDs.
+        self.detokenize = detokenize
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
@@ -210,6 +216,10 @@ class SamplingParams:
         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
+        if self.stop and not self.detokenize:
+            raise ValueError(
+                "stop strings are only supported when detokenize is True. "
+                "Set detokenize=True to use stop.")
 
     def _verify_beam_search(self) -> None:
         if self.best_of == 1:
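For context, a minimal usage sketch of the new engine-level flag, mirroring the test above (offline `LLM` entrypoint with `facebook/opt-125m`; the prompt string here is illustrative). Per the NOTE in the diff, the flag is not exposed through the OpenAI API server.

```python
from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams

llm = LLM(model="facebook/opt-125m")

# detokenize=False skips incremental detokenization and stop-string checks;
# only token IDs are produced (stop_token_ids still apply).
params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False)
out = llm.generate("How do I build a car from cardboard?", params)[0].outputs[0]

assert out.text == ''  # no text is generated
print(out.token_ids)   # the generated token IDs are still available
```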