From 05fcd1b4308aa2dca9a1b24c540b57a94a7ba124 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 17 Apr 2025 07:45:24 -0700 Subject: [PATCH] [V1][Perf] Faster incremental detokenization (#15137) Signed-off-by: Nick Hill --- requirements/common.txt | 2 +- requirements/test.in | 1 + requirements/test.txt | 6 +- tests/lora/test_llama_tp.py | 1 + tests/tokenization/test_detokenize.py | 196 ++++++++++---- .../tool_parsers/mistral_tool_parser.py | 9 + vllm/v1/engine/detokenizer.py | 247 ++++++++++++------ 7 files changed, 317 insertions(+), 145 deletions(-) diff --git a/requirements/common.txt b/requirements/common.txt index 4df32460..33c4c321 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -8,7 +8,7 @@ blake3 py-cpuinfo transformers >= 4.51.1 huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads. -tokenizers >= 0.19.1 # Required for Llama 3. +tokenizers >= 0.21.1 # Required for fast incremental detokenization. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. aiohttp diff --git a/requirements/test.in b/requirements/test.in index c3690f4c..833f26b5 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -35,6 +35,7 @@ opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.8 # required for model evaluation test transformers==4.51.1 +tokenizers==0.21.1 huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads. # quantization bitsandbytes>=0.45.3 diff --git a/requirements/test.txt b/requirements/test.txt index 948c9eda..ee995402 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -624,8 +624,10 @@ tiktoken==0.7.0 # mistral-common timm==1.0.11 # via -r requirements/test.in -tokenizers==0.21.0 - # via transformers +tokenizers==0.21.1 + # via + # -r requirements/test.in + # transformers torch==2.6.0 # via # -r requirements/test.in diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index cdb8c893..e3a054bd 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -47,6 +47,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: ] sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256, + skip_special_tokens=False, stop=["[/assistant]"]) outputs = llm.generate( prompts, diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index b1860e0b..0f8b98a1 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -4,14 +4,22 @@ from collections.abc import Generator from typing import Any, Optional import pytest -from transformers import AutoTokenizer +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) from vllm.inputs import token_inputs from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup -from vllm.transformers_utils.detokenizer import (Detokenizer, - detokenize_incrementally) +from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import get_tokenizer_group from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer, + IncrementalDetokenizer, + SlowIncrementalDetokenizer) + +SPECIAL_TOKS_TRUTH = [ + "Some text with adjacent special tokens <|padding|><|padding|>other text", # noqa +] TRUTH 
= [ "Hello here, this is a simple test", @@ -22,7 +30,8 @@ TRUTH = [ # incomplete UTF-8 characters # see https://github.com/vllm-project/vllm/pull/9625 "ပုံပြင်လေးပြောပြပါ်", -] +] + SPECIAL_TOKS_TRUTH + TOKENIZERS = [ "facebook/opt-125m", "gpt2", @@ -38,26 +47,37 @@ TOKENIZERS = [ ] -def _run_incremental_decode(tokenizer, all_input_ids, - skip_special_tokens: bool, starting_index: int): - decoded_text = "" - offset = 0 - token_offset = 0 - prev_tokens = None - for i in range(starting_index, len(all_input_ids)): - new_tokens, text, offset, token_offset = detokenize_incrementally( - tokenizer, - all_input_ids[:i + 1], - prev_tokens, - offset, - token_offset, - skip_special_tokens=skip_special_tokens) - decoded_text += text - if prev_tokens is None: - prev_tokens = new_tokens - else: - prev_tokens += new_tokens - return decoded_text +def _run_incremental_decode(tokenizer, + all_input_ids, + skip_special_tokens: bool, + starting_index: int, + spaces_between_special_tokens: bool = True, + fast: Optional[bool] = None): + + prompt_token_ids = all_input_ids[:starting_index] + + params = SamplingParams( + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + request = EngineCoreRequest("", "", prompt_token_ids, None, None, None, + params, None, 0.0, None) + + if fast is None: + detokenizer = IncrementalDetokenizer.from_new_request( + tokenizer, request) + elif fast: + detokenizer = FastIncrementalDetokenizer(tokenizer, request) + else: + detokenizer = SlowIncrementalDetokenizer(tokenizer, request) + + output_text = "" + for i, token_id in enumerate(all_input_ids[starting_index:]): + detokenizer.update([token_id], False) + finished = i == len(all_input_ids) - 1 + output_text += detokenizer.get_next_output_text(finished, delta=True) + + return output_text, detokenizer.output_token_ids @pytest.fixture @@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth): starting_index = 0 all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids - decoded_text = _run_incremental_decode(tokenizer, - all_input_ids, - skip_special_tokens=True, - starting_index=starting_index) + decoded_text, out_ids = _run_incremental_decode( + tokenizer, + all_input_ids, + skip_special_tokens=True, + starting_index=starting_index) assert decoded_text == truth + assert out_ids == all_input_ids[starting_index:] @pytest.fixture @@ -106,40 +128,86 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]: @pytest.mark.parametrize("with_prompt", [True, False]) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS) @pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True) -def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens): +@pytest.mark.parametrize("spaces_between_special_tokens", (True, False)) +@pytest.mark.parametrize("fast", (True, False)) +def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens, + spaces_between_special_tokens, fast): + if fast and not isinstance(tokenizer, PreTrainedTokenizerFast): + pytest.skip() + + if skip_special_tokens and not spaces_between_special_tokens: + pytest.skip() + + if not fast and isinstance(tokenizer, PreTrainedTokenizerFast): + # Fix up inconsistency in fast/slow tokenizer behaviour. 
+ tokenizer.add_special_tokens({ + "additional_special_tokens": [ + at for at in + tokenizer._tokenizer.get_added_tokens_decoder().values() + if at.special + ] + }) + + extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \ + else {"spaces_between_special_tokens": spaces_between_special_tokens} + + truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids + if tokenizer.bos_token_id is not None: + truth_tokens.insert(0, tokenizer.bos_token_id) + truth_tokens.append(tokenizer.eos_token_id) + + new_truth = tokenizer.decode(truth_tokens, + skip_special_tokens=skip_special_tokens, + **extra_decode_args) + if with_prompt: - truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids - prompt_input_ids = truth_tokens[:len(truth) // 2] - generated_input_ids = truth_tokens[len(truth) // 2:] + num_prompt_tokens = len( + tokenizer(truth[:len(truth) // 2], + add_special_tokens=False).input_ids) + if tokenizer.bos_token_id is not None: + num_prompt_tokens += 1 + + prompt_input_ids = truth_tokens[:num_prompt_tokens] + generated_input_ids = truth_tokens[num_prompt_tokens:] all_input_ids = prompt_input_ids + generated_input_ids starting_index = len(prompt_input_ids) prompt = tokenizer.decode(prompt_input_ids, - skip_special_tokens=skip_special_tokens) - generated = truth[len(prompt):] - else: - generated = truth - starting_index = 0 - all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids - if skip_special_tokens: - if tokenizer.bos_token_id is not None: - all_input_ids = [tokenizer.bos_token_id] + all_input_ids - starting_index += 1 - all_input_ids = all_input_ids + [tokenizer.eos_token_id] + skip_special_tokens=skip_special_tokens, + **extra_decode_args) - decoded_text = _run_incremental_decode( + generated = new_truth[len(prompt):] + else: + generated = new_truth + starting_index = 0 + all_input_ids = truth_tokens + + decoded_text, out_ids = _run_incremental_decode( tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens, - starting_index=starting_index) + starting_index=starting_index, + spaces_between_special_tokens=spaces_between_special_tokens, + fast=fast) assert decoded_text == generated + assert out_ids == all_input_ids[starting_index:] - decoded_text = _run_incremental_decode( + +@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) +@pytest.mark.parametrize("fast", (True, False)) +def test_oov_decode(tokenizer, fast): + if fast and not isinstance(tokenizer, PreTrainedTokenizerFast): + pytest.skip() + + decoded_text, out_ids = _run_incremental_decode( tokenizer, [len(tokenizer)], - skip_special_tokens=skip_special_tokens, - starting_index=starting_index) + skip_special_tokens=True, + starting_index=0, + spaces_between_special_tokens=True, + fast=fast) assert decoded_text == '' + assert out_ids == [len(tokenizer)] @pytest.fixture @@ -165,15 +233,14 @@ def detokenizer(tokenizer_name: str) -> Detokenizer: @pytest.fixture(name="complete_sequence_token_ids") def create_complete_sequence_token_ids(complete_sequence: str, tokenizer) -> list[int]: - complete_sequence_token_ids = tokenizer(complete_sequence).input_ids - return complete_sequence_token_ids + return tokenizer(complete_sequence, add_special_tokens=False).input_ids def create_sequence(prompt_token_ids=None): - prompt_token_ids = prompt_token_ids or [1] + prompt_token_ids = prompt_token_ids or [] return Sequence( seq_id=0, - inputs=token_inputs(prompt_token_ids, prompt=""), + inputs=token_inputs(prompt_token_ids), block_size=16, ) @@ -224,7 +291,7 @@ def 
test_decode_sequence_logprobs(complete_sequence: str,
     assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
     assert sequential_result != "".join(sequential_logprobs_text_other_token)
 
-    if skip_special_tokens:
+    if not skip_special_tokens:
         # Text for logprobs for the chosen token should be the same as the
-        # generated text. Note that this will only be true if we skip
-        # special tokens.
+        # generated text. Note that this will only be true if we don't
+        # skip special tokens.
@@ -233,10 +300,23 @@
 
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
+def test_decode_prompt_logprobs(complete_sequence: str,
+                                complete_sequence_token_ids: list[int],
                                 detokenizer: Detokenizer):
+
+    # We want to use skip_special_tokens=False here but Mistral tokenizers
+    # don't support that.
+    if complete_sequence not in SPECIAL_TOKS_TRUTH:
+        skip_special_tokens = True
+    elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None),
+                        MistralTokenizer):
+        skip_special_tokens = False
+    else:
+        pytest.skip("MistralTokenizers don't support "
+                    "skip_special_tokens=False")
+        return
+
     """Verify Detokenizer decodes prompt logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=True,
+    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                      prompt_logprobs=1)
 
     # Run sequentially.
@@ -256,8 +336,10 @@
     # decoded_prompt_logprobs doesn't contain the first token.
     token_ids = complete_sequence_token_ids
     tokenizer = detokenizer.get_tokenizer_for_seq(seq)
-    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
-    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
+    text_full = tokenizer.decode(token_ids,
+                                 skip_special_tokens=skip_special_tokens)
+    text_first = tokenizer.decode(token_ids[0],
+                                  skip_special_tokens=skip_special_tokens)
     text = text_full[len(text_first):]
 
     # Text for logprobs for the chosen token should be the same as the
diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
index 06614456..bff6cb79 100644
--- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
@@ -70,6 +70,15 @@ class MistralToolParser(ToolParser):
                 "Mistral Tool Parser could not locate the tool call token in "
                 "the tokenizer!")
 
+    def adjust_request(
+            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+        if request.tools and request.tool_choice != 'none':
+            # do not skip special tokens because mistral uses the special
+            # tokens to indicate the start and end of the tool calls
+            # information.
+ request.skip_special_tokens = False + return request + def extract_tool_calls( self, model_output: str, diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index bf06a175..006d53d8 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 - -from dataclasses import dataclass, field +from abc import ABC, abstractmethod from typing import Optional +from tokenizers import Tokenizer +from tokenizers.decoders import DecodeStream +from transformers import PreTrainedTokenizerFast + from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.transformers_utils.detokenizer_utils import ( @@ -12,39 +15,22 @@ from vllm.v1.engine import EngineCoreRequest logger = init_logger(__name__) -@dataclass class IncrementalDetokenizer: - # Generation data - token_ids: list[int] - output_text: str = "" - tokens: list[str] = field(default_factory=list) - prompt_len: int = 0 - - # Stop strings - stop: list[str] = field(default_factory=list) - include_stop_str_in_output: bool = False - - # Metadata for incremental detokenization - prefix_offset: int = 0 - read_offset: int = 0 - - # Parameters for detokenization - skip_special_tokens: bool = True - spaces_between_special_tokens: bool = True - - # Tokenizer for this request, - # None if detokenization is disabled. - tokenizer: Optional[AnyTokenizer] = None - - # Accounting for stop string buffering - stop_buffer_length: int = 0 - _last_output_text_offset: int = 0 + def __init__(self): + self.token_ids: list[int] = [] @property def output_token_ids(self) -> list[int]: - return self.token_ids if not self.prompt_len else ( - self.token_ids[self.prompt_len:]) + return self.token_ids + + def update(self, new_token_ids: list[int], + stop_terminated: bool) -> Optional[str]: + self.token_ids.extend(new_token_ids) + return None + + def get_next_output_text(self, finished: bool, delta: bool) -> str: + return "" @classmethod def from_new_request( @@ -54,39 +40,37 @@ class IncrementalDetokenizer: ) -> "IncrementalDetokenizer": if tokenizer is None: - return cls(token_ids=[]) + # No tokenizer => skipping detokenization. + return IncrementalDetokenizer() - tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens( - tokenizer=tokenizer, - prompt_ids=request.prompt_token_ids, - skip_special_tokens=request.sampling_params.skip_special_tokens, - ) + if isinstance(tokenizer, PreTrainedTokenizerFast): + # Fast tokenizer => use tokenizers library DecodeStream. + return FastIncrementalDetokenizer(tokenizer, request) + + # Fall back to slow python-based incremental detokenization. + return SlowIncrementalDetokenizer(tokenizer, request) + + +class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): + + def __init__(self, request: EngineCoreRequest): + super().__init__() + + # Stop strings + params = request.sampling_params + self.stop = stop = params.stop + self.include_stop_str_in_output = params.include_stop_str_in_output - stops = request.sampling_params.stop # Number of chars to hold back when stop strings are to be excluded # from streamed output. 
- if stops and not request.sampling_params.include_stop_str_in_output: - stop_buffer_length = max(len(s) for s in stops) - 1 + if stop and not self.include_stop_str_in_output: + self.stop_buffer_length = max(len(s) for s in stop) - 1 else: - stop_buffer_length = 0 + self.stop_buffer_length = 0 + self._last_output_text_offset: int = 0 - return cls( - tokens=tokens, - # Detokenizer mutates this list, so need a unique copy. - # NOTE(Nick): could we take ownership of it though? - token_ids=request.prompt_token_ids.copy(), - stop=stops, - include_stop_str_in_output=request.sampling_params. - include_stop_str_in_output, - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=request.sampling_params.skip_special_tokens, - spaces_between_special_tokens=request.sampling_params. - spaces_between_special_tokens, - prompt_len=len(request.prompt_token_ids), - tokenizer=tokenizer, - stop_buffer_length=stop_buffer_length, - ) + # Generation data + self.output_text = "" def update(self, new_token_ids: list[int], stop_terminated: bool) -> Optional[str]: @@ -98,11 +82,7 @@ class IncrementalDetokenizer: Return matched stop string or None. """ if not new_token_ids: - # Skip detokenization if no new token ids - return None - if self.tokenizer is None: - # Skip detokenization if no tokenizer - self.token_ids.extend(new_token_ids) + # Skip detokenization if no new token ids. return None if stop_terminated and not self.include_stop_str_in_output: @@ -116,34 +96,16 @@ class IncrementalDetokenizer: # 1) Detokenize the new token ids incrementally. # TODO(woosuk): This method becomes very inefficient when the number of # new_token_ids is more than 1. We need to optimize this. - decoded_text = "" + offset_before = len(self.output_text) for new_token_id in new_token_ids: self.token_ids.append(new_token_id) - (new_tokens, new_decoded_token_text, prefix_offset, - read_offset) = detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=self.token_ids, - prev_tokens=self.tokens, - prefix_offset=self.prefix_offset, - read_offset=self.read_offset, - skip_special_tokens=self.skip_special_tokens, - spaces_between_special_tokens=self. - spaces_between_special_tokens, - ) - - self.tokens.extend(new_tokens) - self.prefix_offset = prefix_offset - self.read_offset = read_offset - - decoded_text += new_decoded_token_text - - self.output_text += decoded_text + self.output_text += self.decode_next(new_token_id) if stop_terminated: if skipped_stop_token_id is not None: - # Cleanup after skipping detokenization + # Cleanup after skipping detokenization. self.token_ids.append(skipped_stop_token_id) - # Stop token triggered; skip stop string check + # Stop token triggered; skip stop string check. return None # 2) Evaluate stop strings. 
@@ -151,7 +113,7 @@
         if self.stop:
             stop = StopChecker.check_stop_strings(
                 output_text=self.output_text,
-                new_char_count=len(decoded_text),
+                new_char_count=len(self.output_text) - offset_before,
                 stop=self.stop,
                 include_in_output=self.include_stop_str_in_output,
             )
@@ -162,6 +124,10 @@
 
         return stop_string
 
+    @abstractmethod
+    def decode_next(self, next_token_id: int) -> str:
+        raise NotImplementedError
+
     def get_next_output_text(self, finished: bool, delta: bool) -> str:
         """If delta is True, only new text since the last call to this method
         is returned"""
@@ -177,3 +143,114 @@
             self._last_output_text_offset = length
             return self.output_text[last_offset:length]
         return ""
+
+
+class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
+
+    def __init__(self, tokenizer: PreTrainedTokenizerFast,
+                 request: EngineCoreRequest):
+        super().__init__(request)
+
+        sampling_params = request.sampling_params
+        self.stream = DecodeStream(
+            skip_special_tokens=sampling_params.skip_special_tokens)
+
+        self.tokenizer: Tokenizer = tokenizer._tokenizer
+
+        # Find a safe place to start.
+        prompt_suffix = request.prompt_token_ids
+        prompt_len = len(prompt_suffix)
+        if prompt_len > 4:
+            for i in range(4, min(prompt_len + 1, 32)):
+                suffix = request.prompt_token_ids[-i:]
+                if '�' not in self.tokenizer.decode(suffix):
+                    prompt_suffix = suffix
+                    break
+
+        # Prime the stream.
+        for tid in prompt_suffix:
+            self.stream.step(self.tokenizer, tid)
+
+        self.spaces_between_special_tokens = (
+            sampling_params.skip_special_tokens
+            or sampling_params.spaces_between_special_tokens)
+
+        if not self.spaces_between_special_tokens:
+            # Store dict of added token ids so that we can suppress
+            # the spaces between them.
+            if (added_token_ids := getattr(self.tokenizer, "added_token_ids",
+                                           None)) is None:
+                self.tokenizer.added_token_ids = added_token_ids = {
+                    tid: tok.content
+                    for tid, tok in
+                    self.tokenizer.get_added_tokens_decoder().items()
+                }
+
+            if added_token_ids:
+                self.last_special = False
+                self.added_token_ids = added_token_ids
+            else:
+                # No added tokens.
+                self.spaces_between_special_tokens = True
+
+    def decode_next(self, next_token_id: int) -> str:
+        token = self.stream.step(self.tokenizer, next_token_id)
+
+        if not self.spaces_between_special_tokens:
+            special_token = self.added_token_ids.get(next_token_id)
+            is_special = special_token is not None
+            if is_special and self.last_special:
+                # Return raw token string without any prefixed spaces.
+                token = special_token
+            self.last_special = is_special
+
+        return token or ""
+
+
+class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
+
+    def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
+        super().__init__(request)
+
+        self.tokenizer = tokenizer
+
+        # Metadata for incremental detokenization.
+        self.tokens, self.prefix_offset, self.read_offset = (
+            convert_prompt_ids_to_tokens(
+                tokenizer=tokenizer,
+                prompt_ids=request.prompt_token_ids,
+                skip_special_tokens=request.sampling_params.
+ skip_special_tokens, + )) + + self.token_ids.extend(request.prompt_token_ids) + self.prompt_len = len(request.prompt_token_ids) + + params = request.sampling_params + self.skip_special_tokens = params.skip_special_tokens + self.spaces_between_special_tokens = ( + params.spaces_between_special_tokens) + + @property + def output_token_ids(self) -> list[int]: + return self.token_ids if not self.prompt_len else ( + self.token_ids[self.prompt_len:]) + + def decode_next(self, next_token_id: int) -> str: + new_tokens, decoded_text, prefix_offset, read_offset = ( + detokenize_incrementally( + tokenizer=self.tokenizer, + all_input_ids=self.token_ids, + prev_tokens=self.tokens, + prefix_offset=self.prefix_offset, + read_offset=self.read_offset, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self. + spaces_between_special_tokens, + )) + + self.tokens.extend(new_tokens) + self.prefix_offset = prefix_offset + self.read_offset = read_offset + + return decoded_text
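
Note on the fast path: FastIncrementalDetokenizer is built on the DecodeStream API
that ships with tokenizers >= 0.21.1, which is why requirements/common.txt bumps
that pin. Below is a minimal standalone sketch of the same streaming pattern for
readers unfamiliar with the API; the "gpt2" model name and the sample strings are
illustrative assumptions, not part of this patch:

    from tokenizers import Tokenizer
    from tokenizers.decoders import DecodeStream

    # Any HF fast tokenizer works here; "gpt2" is just an example.
    tokenizer = Tokenizer.from_pretrained("gpt2")
    stream = DecodeStream(skip_special_tokens=True)

    prompt_ids = tokenizer.encode("Hello here, this is a simple test").ids

    # Mirror FastIncrementalDetokenizer: prime the stream on a short prompt
    # suffix, backing off until the suffix decodes without a U+FFFD
    # replacement character (i.e. it does not start mid UTF-8 sequence).
    suffix = prompt_ids
    if len(prompt_ids) > 4:
        for i in range(4, min(len(prompt_ids) + 1, 32)):
            candidate = prompt_ids[-i:]
            if "\ufffd" not in tokenizer.decode(candidate):
                suffix = candidate
                break
    for tid in suffix:
        stream.step(tokenizer, tid)

    # Stream the "generated" tokens; step() returns None while the pending
    # bytes do not yet form a complete UTF-8 character, then yields the
    # newly decodable text delta.
    output = ""
    for tid in tokenizer.encode(" and some generated text").ids:
        piece = stream.step(tokenizer, tid)
        if piece is not None:
            output += piece
    print(output)

Priming on a bounded suffix rather than the whole prompt keeps per-request setup
cost roughly constant in prompt length, while still giving the decoder enough
context to emit correct leading whitespace for the first generated token.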