[V1][Perf] Faster incremental detokenization (#15137)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-04-17 07:45:24 -07:00 committed by GitHub
parent 7c02d6a137
commit 05fcd1b430
7 changed files with 317 additions and 145 deletions


@@ -8,7 +8,7 @@ blake3
 py-cpuinfo
 transformers >= 4.51.1
 huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
-tokenizers >= 0.19.1  # Required for Llama 3.
+tokenizers >= 0.21.1  # Required for fast incremental detokenization.
 protobuf # Required by LlamaTokenizer.
 fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
 aiohttp
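The tokenizers floor moves from 0.19.1 to 0.21.1 because the new fast path in the detokenizer changes below relies on the tokenizers library's DecodeStream incremental decoder. A minimal sketch of that API in isolation (the model name is illustrative; behaviour assumes tokenizers >= 0.21.1):

    from tokenizers import Tokenizer
    from tokenizers.decoders import DecodeStream

    tok = Tokenizer.from_pretrained("gpt2")  # any fast tokenizer
    stream = DecodeStream(skip_special_tokens=False)

    text = ""
    for token_id in tok.encode("Hello here, this is a simple test").ids:
        # step() returns the next printable fragment, or None if this token
        # does not yet complete one (e.g. a partial UTF-8 sequence).
        piece = stream.step(tok, token_id)
        if piece is not None:
            text += piece
    print(text)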


@@ -35,6 +35,7 @@ opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test
 lm-eval[api]==0.4.8 # required for model evaluation test
 transformers==4.51.1
+tokenizers==0.21.1
 huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
 # quantization
 bitsandbytes>=0.45.3


@@ -624,8 +624,10 @@ tiktoken==0.7.0
     #   mistral-common
 timm==1.0.11
     # via -r requirements/test.in
-tokenizers==0.21.0
-    # via transformers
+tokenizers==0.21.1
+    # via
+    #   -r requirements/test.in
+    #   transformers
 torch==2.6.0
     # via
     #   -r requirements/test.in


@@ -47,6 +47,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
     ]
     sampling_params = vllm.SamplingParams(temperature=0,
                                           max_tokens=256,
+                                          skip_special_tokens=False,
                                           stop=["[/assistant]"])
     outputs = llm.generate(
         prompts,


@@ -4,14 +4,22 @@ from collections.abc import Generator
 from typing import Any, Optional

 import pytest
-from transformers import AutoTokenizer
+from transformers import (AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)

 from vllm.inputs import token_inputs
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.detokenizer import (Detokenizer,
-                                                 detokenize_incrementally)
+from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
+from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
+                                        IncrementalDetokenizer,
+                                        SlowIncrementalDetokenizer)
+
+SPECIAL_TOKS_TRUTH = [
+    "Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
+]

 TRUTH = [
     "Hello here, this is a simple test",
@@ -22,7 +30,8 @@ TRUTH = [
     # incomplete UTF-8 characters
     # see https://github.com/vllm-project/vllm/pull/9625
     "ပုံပြင်လေးပြောပြပါ်",
-]
+] + SPECIAL_TOKS_TRUTH
+
 TOKENIZERS = [
     "facebook/opt-125m",
     "gpt2",
@@ -38,26 +47,37 @@ TOKENIZERS = [
 ]


-def _run_incremental_decode(tokenizer, all_input_ids,
-                            skip_special_tokens: bool, starting_index: int):
-    decoded_text = ""
-    offset = 0
-    token_offset = 0
-    prev_tokens = None
-    for i in range(starting_index, len(all_input_ids)):
-        new_tokens, text, offset, token_offset = detokenize_incrementally(
-            tokenizer,
-            all_input_ids[:i + 1],
-            prev_tokens,
-            offset,
-            token_offset,
-            skip_special_tokens=skip_special_tokens)
-        decoded_text += text
-        if prev_tokens is None:
-            prev_tokens = new_tokens
-        else:
-            prev_tokens += new_tokens
-    return decoded_text
+def _run_incremental_decode(tokenizer,
+                            all_input_ids,
+                            skip_special_tokens: bool,
+                            starting_index: int,
+                            spaces_between_special_tokens: bool = True,
+                            fast: Optional[bool] = None):
+
+    prompt_token_ids = all_input_ids[:starting_index]
+
+    params = SamplingParams(
+        skip_special_tokens=skip_special_tokens,
+        spaces_between_special_tokens=spaces_between_special_tokens,
+    )
+    request = EngineCoreRequest("", "", prompt_token_ids, None, None, None,
+                                params, None, 0.0, None)
+
+    if fast is None:
+        detokenizer = IncrementalDetokenizer.from_new_request(
+            tokenizer, request)
+    elif fast:
+        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
+    else:
+        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)
+
+    output_text = ""
+    for i, token_id in enumerate(all_input_ids[starting_index:]):
+        detokenizer.update([token_id], False)
+        finished = i == len(all_input_ids) - 1
+        output_text += detokenizer.get_next_output_text(finished, delta=True)
+
+    return output_text, detokenizer.output_token_ids


 @pytest.fixture
@@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth):
     starting_index = 0
     all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids

-    decoded_text = _run_incremental_decode(tokenizer,
-                                           all_input_ids,
-                                           skip_special_tokens=True,
-                                           starting_index=starting_index)
+    decoded_text, out_ids = _run_incremental_decode(
+        tokenizer,
+        all_input_ids,
+        skip_special_tokens=True,
+        starting_index=starting_index)
     assert decoded_text == truth
+    assert out_ids == all_input_ids[starting_index:]


 @pytest.fixture
@@ -106,40 +128,86 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
 @pytest.mark.parametrize("with_prompt", [True, False])
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
 @pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
-def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
-    if with_prompt:
-        truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
-        prompt_input_ids = truth_tokens[:len(truth) // 2]
-        generated_input_ids = truth_tokens[len(truth) // 2:]
+@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
+@pytest.mark.parametrize("fast", (True, False))
+def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
+                          spaces_between_special_tokens, fast):
+    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
+        pytest.skip()
+
+    if skip_special_tokens and not spaces_between_special_tokens:
+        pytest.skip()
+
+    if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
+        # Fix up inconsistency in fast/slow tokenizer behaviour.
+        tokenizer.add_special_tokens({
+            "additional_special_tokens": [
+                at for at in
+                tokenizer._tokenizer.get_added_tokens_decoder().values()
+                if at.special
+            ]
+        })
+
+    extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \
+        else {"spaces_between_special_tokens": spaces_between_special_tokens}
+
+    truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
+    if tokenizer.bos_token_id is not None:
+        truth_tokens.insert(0, tokenizer.bos_token_id)
+    truth_tokens.append(tokenizer.eos_token_id)
+
+    new_truth = tokenizer.decode(truth_tokens,
+                                 skip_special_tokens=skip_special_tokens,
+                                 **extra_decode_args)
+
+    if with_prompt:
+        num_prompt_tokens = len(
+            tokenizer(truth[:len(truth) // 2],
+                      add_special_tokens=False).input_ids)
+        if tokenizer.bos_token_id is not None:
+            num_prompt_tokens += 1
+
+        prompt_input_ids = truth_tokens[:num_prompt_tokens]
+        generated_input_ids = truth_tokens[num_prompt_tokens:]
         all_input_ids = prompt_input_ids + generated_input_ids
         starting_index = len(prompt_input_ids)
         prompt = tokenizer.decode(prompt_input_ids,
-                                  skip_special_tokens=skip_special_tokens)
-        generated = truth[len(prompt):]
-    else:
-        generated = truth
-        starting_index = 0
-        all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
-        if skip_special_tokens:
-            if tokenizer.bos_token_id is not None:
-                all_input_ids = [tokenizer.bos_token_id] + all_input_ids
-                starting_index += 1
-            all_input_ids = all_input_ids + [tokenizer.eos_token_id]
-
-    decoded_text = _run_incremental_decode(
+                                  skip_special_tokens=skip_special_tokens,
+                                  **extra_decode_args)
+
+        generated = new_truth[len(prompt):]
+    else:
+        generated = new_truth
+        starting_index = 0
+        all_input_ids = truth_tokens
+
+    decoded_text, out_ids = _run_incremental_decode(
         tokenizer,
         all_input_ids,
         skip_special_tokens=skip_special_tokens,
-        starting_index=starting_index)
+        starting_index=starting_index,
+        spaces_between_special_tokens=spaces_between_special_tokens,
+        fast=fast)

     assert decoded_text == generated
+    assert out_ids == all_input_ids[starting_index:]

-    decoded_text = _run_incremental_decode(
+
+@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
+@pytest.mark.parametrize("fast", (True, False))
+def test_oov_decode(tokenizer, fast):
+
+    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
+        pytest.skip()
+
+    decoded_text, out_ids = _run_incremental_decode(
         tokenizer, [len(tokenizer)],
-        skip_special_tokens=skip_special_tokens,
-        starting_index=starting_index)
+        skip_special_tokens=True,
+        starting_index=0,
+        spaces_between_special_tokens=True,
+        fast=fast)

     assert decoded_text == ''
+    assert out_ids == [len(tokenizer)]


 @pytest.fixture
@@ -165,15 +233,14 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
 @pytest.fixture(name="complete_sequence_token_ids")
 def create_complete_sequence_token_ids(complete_sequence: str,
                                        tokenizer) -> list[int]:
-    complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
-    return complete_sequence_token_ids
+    return tokenizer(complete_sequence, add_special_tokens=False).input_ids


 def create_sequence(prompt_token_ids=None):
-    prompt_token_ids = prompt_token_ids or [1]
+    prompt_token_ids = prompt_token_ids or []
     return Sequence(
         seq_id=0,
-        inputs=token_inputs(prompt_token_ids, prompt="<s>"),
+        inputs=token_inputs(prompt_token_ids),
         block_size=16,
     )
@@ -224,7 +291,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
     assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
     assert sequential_result != "".join(sequential_logprobs_text_other_token)

-    if skip_special_tokens:
+    if not skip_special_tokens:
         # Text for logprobs for the chosen token should be the same as the
         # generated text. Note that this will only be true if we skip
         # special tokens.
@@ -233,10 +300,23 @@ def test_decode_sequence_logprobs(complete_sequence: str,

 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
+def test_decode_prompt_logprobs(complete_sequence: str,
+                                complete_sequence_token_ids: list[int],
                                 detokenizer: Detokenizer):
+
+    # We want to use skip_special_tokens=False here but Mistral tokenizers
+    # don't support that.
+    if complete_sequence not in SPECIAL_TOKS_TRUTH:
+        skip_special_tokens = True
+    elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None),
+                        MistralTokenizer):
+        skip_special_tokens = False
+    else:
+        pytest.skip("MistralTokenizers don't support "
+                    "skip_special_tokens=False")
+        return
+
     """Verify Detokenizer decodes prompt logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=True,
+    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                      prompt_logprobs=1)

     # Run sequentially.
@@ -256,8 +336,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
     # decoded_prompt_logprobs doesn't contain the first token.
     token_ids = complete_sequence_token_ids
     tokenizer = detokenizer.get_tokenizer_for_seq(seq)
-    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
-    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
+    text_full = tokenizer.decode(token_ids,
+                                 skip_special_tokens=skip_special_tokens)
+    text_first = tokenizer.decode(token_ids[0],
+                                  skip_special_tokens=skip_special_tokens)
     text = text_full[len(text_first):]

     # Text for logprobs for the chosen token should be the same as the


@@ -70,6 +70,15 @@ class MistralToolParser(ToolParser):
                 "Mistral Tool Parser could not locate the tool call token in "
                 "the tokenizer!")

+    def adjust_request(
+            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+        if request.tools and request.tool_choice != 'none':
+            # do not skip special tokens because mistral uses the special
+            # tokens to indicate the start and end of the tool calls
+            # information.
+            request.skip_special_tokens = False
+        return request
+
     def extract_tool_calls(
         self,
         model_output: str,
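The comment in adjust_request carries the reasoning: Mistral models mark tool calls with a special control token (rendered as [TOOL_CALLS] by Mistral tokenizers), and if detokenization strips special tokens that marker never reaches extract_tool_calls. A small illustration of what the parser would see either way; the JSON payload below is made up:

    # With skip_special_tokens=False the marker survives decoding:
    with_marker = '[TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}}]'
    # With skip_special_tokens=True the special token is silently dropped:
    without_marker = '[{"name": "get_weather", "arguments": {"city": "Paris"}}]'

    # Only the first form lets the parser tell a tool call apart from
    # ordinary generated text.
    assert with_marker.startswith("[TOOL_CALLS]")
    assert not without_marker.startswith("[TOOL_CALLS]")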


@@ -1,8 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0

-from dataclasses import dataclass, field
+from abc import ABC, abstractmethod
 from typing import Optional

+from tokenizers import Tokenizer
+from tokenizers.decoders import DecodeStream
+from transformers import PreTrainedTokenizerFast
+
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
 from vllm.transformers_utils.detokenizer_utils import (
@@ -12,39 +15,22 @@ from vllm.v1.engine import EngineCoreRequest

 logger = init_logger(__name__)


-@dataclass
 class IncrementalDetokenizer:

-    # Generation data
-    token_ids: list[int]
-    output_text: str = ""
-    tokens: list[str] = field(default_factory=list)
-    prompt_len: int = 0
-
-    # Stop strings
-    stop: list[str] = field(default_factory=list)
-    include_stop_str_in_output: bool = False
-
-    # Metadata for incremental detokenization
-    prefix_offset: int = 0
-    read_offset: int = 0
-
-    # Parameters for detokenization
-    skip_special_tokens: bool = True
-    spaces_between_special_tokens: bool = True
-
-    # Tokenizer for this request,
-    # None if detokenization is disabled.
-    tokenizer: Optional[AnyTokenizer] = None
-
-    # Accounting for stop string buffering
-    stop_buffer_length: int = 0
-    _last_output_text_offset: int = 0
+    def __init__(self):
+        self.token_ids: list[int] = []

     @property
     def output_token_ids(self) -> list[int]:
-        return self.token_ids if not self.prompt_len else (
-            self.token_ids[self.prompt_len:])
+        return self.token_ids
+
+    def update(self, new_token_ids: list[int],
+               stop_terminated: bool) -> Optional[str]:
+        self.token_ids.extend(new_token_ids)
+        return None
+
+    def get_next_output_text(self, finished: bool, delta: bool) -> str:
+        return ""

     @classmethod
     def from_new_request(
@@ -54,39 +40,37 @@
     ) -> "IncrementalDetokenizer":

         if tokenizer is None:
-            return cls(token_ids=[])
-
-        tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
-            tokenizer=tokenizer,
-            prompt_ids=request.prompt_token_ids,
-            skip_special_tokens=request.sampling_params.skip_special_tokens,
-        )
-
-        stops = request.sampling_params.stop
+            # No tokenizer => skipping detokenization.
+            return IncrementalDetokenizer()
+
+        if isinstance(tokenizer, PreTrainedTokenizerFast):
+            # Fast tokenizer => use tokenizers library DecodeStream.
+            return FastIncrementalDetokenizer(tokenizer, request)
+
+        # Fall back to slow python-based incremental detokenization.
+        return SlowIncrementalDetokenizer(tokenizer, request)
+
+
+class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
+
+    def __init__(self, request: EngineCoreRequest):
+        super().__init__()
+
+        # Stop strings
+        params = request.sampling_params
+        self.stop = stop = params.stop
+        self.include_stop_str_in_output = params.include_stop_str_in_output
+
         # Number of chars to hold back when stop strings are to be excluded
         # from streamed output.
-        if stops and not request.sampling_params.include_stop_str_in_output:
-            stop_buffer_length = max(len(s) for s in stops) - 1
+        if stop and not self.include_stop_str_in_output:
+            self.stop_buffer_length = max(len(s) for s in stop) - 1
         else:
-            stop_buffer_length = 0
-
-        return cls(
-            tokens=tokens,
-            # Detokenizer mutates this list, so need a unique copy.
-            # NOTE(Nick): could we take ownership of it though?
-            token_ids=request.prompt_token_ids.copy(),
-            stop=stops,
-            include_stop_str_in_output=request.sampling_params.
-            include_stop_str_in_output,
-            prefix_offset=prefix_offset,
-            read_offset=read_offset,
-            skip_special_tokens=request.sampling_params.skip_special_tokens,
-            spaces_between_special_tokens=request.sampling_params.
-            spaces_between_special_tokens,
-            prompt_len=len(request.prompt_token_ids),
-            tokenizer=tokenizer,
-            stop_buffer_length=stop_buffer_length,
-        )
+            self.stop_buffer_length = 0
+        self._last_output_text_offset: int = 0
+
+        # Generation data
+        self.output_text = ""

     def update(self, new_token_ids: list[int],
                stop_terminated: bool) -> Optional[str]:
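For orientation, a hedged sketch of how the new dispatch is driven from the caller's side, mirroring the test helper earlier in this commit. Building an EngineCoreRequest by hand (and its positional argument order) is purely illustrative here; normally the engine constructs the request, and the token ids below are arbitrary:

    from transformers import AutoTokenizer

    from vllm import SamplingParams
    from vllm.v1.engine import EngineCoreRequest
    from vllm.v1.engine.detokenizer import IncrementalDetokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    params = SamplingParams(skip_special_tokens=True, stop=["###"])
    request = EngineCoreRequest("", "", [], None, None, None, params, None,
                                0.0, None)

    # Picks FastIncrementalDetokenizer for fast tokenizers, else the slow path.
    detok = IncrementalDetokenizer.from_new_request(tokenizer, request)

    for token_id in [9906, 1917]:  # ids as they stream from the engine
        # update() returns a matched stop string, or None to keep going.
        if detok.update([token_id], stop_terminated=False) is not None:
            break
        print(detok.get_next_output_text(finished=False, delta=True), end="")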
@@ -98,11 +82,7 @@ class IncrementalDetokenizer:
         Return matched stop string or None.
         """
         if not new_token_ids:
-            # Skip detokenization if no new token ids
-            return None
-
-        if self.tokenizer is None:
-            # Skip detokenization if no tokenizer
-            self.token_ids.extend(new_token_ids)
+            # Skip detokenization if no new token ids.
             return None

         if stop_terminated and not self.include_stop_str_in_output:
@@ -116,34 +96,16 @@
         # 1) Detokenize the new token ids incrementally.
         # TODO(woosuk): This method becomes very inefficient when the number of
         # new_token_ids is more than 1. We need to optimize this.
-        decoded_text = ""
+        offset_before = len(self.output_text)
         for new_token_id in new_token_ids:
             self.token_ids.append(new_token_id)
-            (new_tokens, new_decoded_token_text, prefix_offset,
-             read_offset) = detokenize_incrementally(
-                 tokenizer=self.tokenizer,
-                 all_input_ids=self.token_ids,
-                 prev_tokens=self.tokens,
-                 prefix_offset=self.prefix_offset,
-                 read_offset=self.read_offset,
-                 skip_special_tokens=self.skip_special_tokens,
-                 spaces_between_special_tokens=self.
-                 spaces_between_special_tokens,
-             )
-            self.tokens.extend(new_tokens)
-            self.prefix_offset = prefix_offset
-            self.read_offset = read_offset
-            decoded_text += new_decoded_token_text
-
-        self.output_text += decoded_text
+            self.output_text += self.decode_next(new_token_id)

         if stop_terminated:
             if skipped_stop_token_id is not None:
-                # Cleanup after skipping detokenization
+                # Cleanup after skipping detokenization.
                 self.token_ids.append(skipped_stop_token_id)
-            # Stop token triggered; skip stop string check
+            # Stop token triggered; skip stop string check.
             return None

         # 2) Evaluate stop strings.
@@ -151,7 +113,7 @@
         if self.stop:
             stop = StopChecker.check_stop_strings(
                 output_text=self.output_text,
-                new_char_count=len(decoded_text),
+                new_char_count=len(self.output_text) - offset_before,
                 stop=self.stop,
                 include_in_output=self.include_stop_str_in_output,
             )
@@ -162,6 +124,10 @@

         return stop_string

+    @abstractmethod
+    def decode_next(self, next_token_id: int) -> str:
+        raise NotImplementedError
+
     def get_next_output_text(self, finished: bool, delta: bool) -> str:
         """If delta is True, only new text since the last call to
         this method is returned"""
@@ -177,3 +143,114 @@
             self._last_output_text_offset = length
             return self.output_text[last_offset:length]
         return ""
+
+
+class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
+
+    def __init__(self, tokenizer: PreTrainedTokenizerFast,
+                 request: EngineCoreRequest):
+        super().__init__(request)
+
+        sampling_params = request.sampling_params
+        self.stream = DecodeStream(
+            skip_special_tokens=sampling_params.skip_special_tokens)
+
+        self.tokenizer: Tokenizer = tokenizer._tokenizer
+
+        # Find a safe place to start.
+        prompt_suffix = request.prompt_token_ids
+        prompt_len = len(prompt_suffix)
+        if prompt_len > 4:
+            for i in range(4, max(prompt_len + 1, 32)):
+                suffix = request.prompt_token_ids[-i:]
+                if '�' not in self.tokenizer.decode(suffix):
+                    prompt_suffix = suffix
+                    break
+
+        # Prime the stream.
+        for tid in prompt_suffix:
+            self.stream.step(self.tokenizer, tid)
+
+        self.spaces_between_special_tokens = (
+            sampling_params.skip_special_tokens
+            or sampling_params.spaces_between_special_tokens)
+
+        if not self.spaces_between_special_tokens:
+            # Store dict of added token ids so that we can suppress
+            # the spaces between them.
+            if (added_token_ids := getattr(self.tokenizer, "added_token_ids",
+                                           None)) is None:
+                self.tokenizer.added_token_ids = added_token_ids = {
+                    tid: tok.content
+                    for tid, tok in
+                    self.tokenizer.get_added_tokens_decoder().items()
+                }
+
+            if added_token_ids:
+                self.last_special = False
+                self.added_token_ids = added_token_ids
+            else:
+                # No added tokens.
+                self.spaces_between_special_tokens = True
+
+    def decode_next(self, next_token_id: int) -> str:
+        token = self.stream.step(self.tokenizer, next_token_id)
+
+        if not self.spaces_between_special_tokens:
+            special_token = self.added_token_ids.get(next_token_id)
+            is_special = special_token is not None
+            if is_special and self.last_special:
+                # Return raw token string without any prefixed spaces.
+                token = special_token
+            self.last_special = is_special
+
+        return token or ""
+
+
+class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
+
+    def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
+        super().__init__(request)
+
+        self.tokenizer = tokenizer
+
+        # Metadata for incremental detokenization.
+        self.tokens, self.prefix_offset, self.read_offset = (
+            convert_prompt_ids_to_tokens(
+                tokenizer=tokenizer,
+                prompt_ids=request.prompt_token_ids,
+                skip_special_tokens=request.sampling_params.
+                skip_special_tokens,
+            ))
+
+        self.token_ids.extend(request.prompt_token_ids)
+        self.prompt_len = len(request.prompt_token_ids)
+
+        params = request.sampling_params
+        self.skip_special_tokens = params.skip_special_tokens
+        self.spaces_between_special_tokens = (
+            params.spaces_between_special_tokens)
+
+    @property
+    def output_token_ids(self) -> list[int]:
+        return self.token_ids if not self.prompt_len else (
+            self.token_ids[self.prompt_len:])
+
+    def decode_next(self, next_token_id: int) -> str:
+        new_tokens, decoded_text, prefix_offset, read_offset = (
+            detokenize_incrementally(
+                tokenizer=self.tokenizer,
+                all_input_ids=self.token_ids,
+                prev_tokens=self.tokens,
+                prefix_offset=self.prefix_offset,
+                read_offset=self.read_offset,
+                skip_special_tokens=self.skip_special_tokens,
+                spaces_between_special_tokens=self.
+                spaces_between_special_tokens,
+            ))
+
+        self.tokens.extend(new_tokens)
+        self.prefix_offset = prefix_offset
+        self.read_offset = read_offset
+
+        return decoded_text
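Why FastIncrementalDetokenizer primes its DecodeStream with a prompt suffix instead of starting cold: the first generated fragments decode in the context of the trailing prompt tokens (spacing, byte-level merges), and a suffix that itself decodes to the replacement character (an incomplete byte sequence) would be a bad starting point. A hedged sketch of the idea using the tokenizers API directly, outside vLLM (model name and suffix lengths are illustrative):

    from tokenizers import Tokenizer
    from tokenizers.decoders import DecodeStream

    tok = Tokenizer.from_pretrained("gpt2")
    prompt_ids = tok.encode("Tell me a story about a").ids
    generated_ids = tok.encode(" brave little tokenizer").ids

    stream = DecodeStream(skip_special_tokens=True)

    # Prime with the last few prompt tokens, widening the suffix if it
    # decodes to '\ufffd' (an incomplete byte sequence), as the new class does.
    suffix = prompt_ids[-4:]
    if '\ufffd' in tok.decode(suffix):
        suffix = prompt_ids[-8:]
    for tid in suffix:
        stream.step(tok, tid)

    # Generated tokens now stream out with spacing consistent with the prompt.
    pieces = [stream.step(tok, tid) or "" for tid in generated_ids]
    print("".join(pieces))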