[V1][Perf] Faster incremental detokenization (#15137)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in: parent 7c02d6a137, commit 05fcd1b430
@@ -8,7 +8,7 @@ blake3
py-cpuinfo
transformers >= 4.51.1
huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
tokenizers >= 0.19.1 # Required for Llama 3.
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
protobuf # Required by LlamaTokenizer.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
aiohttp
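The bump from tokenizers 0.19.1 to 0.21.1 is what enables the fast path: recent tokenizers releases expose DecodeStream, a stateful streaming decoder that the new V1 detokenizer builds on. A minimal sketch of that API outside of vLLM (the model name is only an example):

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

# Any fast (Rust-backed) tokenizer works; gpt2 is just an example.
tok = Tokenizer.from_pretrained("gpt2")
stream = DecodeStream(skip_special_tokens=False)

text = ""
for token_id in tok.encode("Hello here, this is a simple test").ids:
    # step() returns the newly completed text for this token, or None
    # when more tokens are needed (e.g. a partial UTF-8 sequence).
    piece = stream.step(tok, token_id)
    if piece is not None:
        text += piece
print(text)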
@@ -35,6 +35,7 @@ opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
transformers==4.51.1
tokenizers==0.21.1
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
# quantization
bitsandbytes>=0.45.3
@@ -624,8 +624,10 @@ tiktoken==0.7.0
# mistral-common
timm==1.0.11
# via -r requirements/test.in
tokenizers==0.21.0
# via transformers
tokenizers==0.21.1
# via
# -r requirements/test.in
# transformers
torch==2.6.0
# via
# -r requirements/test.in
@@ -47,6 +47,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
]
sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=256,
skip_special_tokens=False,
stop=["[/assistant]"])
outputs = llm.generate(
prompts,
@@ -4,14 +4,22 @@ from collections.abc import Generator
from typing import Any, Optional

import pytest
from transformers import AutoTokenizer
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)

from vllm.inputs import token_inputs
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.detokenizer import (Detokenizer,
detokenize_incrementally)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
IncrementalDetokenizer,
SlowIncrementalDetokenizer)

SPECIAL_TOKS_TRUTH = [
"Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>", # noqa
]

TRUTH = [
"Hello here, this is a simple test",

@@ -22,7 +30,8 @@ TRUTH = [
# incomplete UTF-8 characters
# see https://github.com/vllm-project/vllm/pull/9625
"ပုံပြင်လေးပြောပြပါ်",
]
] + SPECIAL_TOKS_TRUTH

TOKENIZERS = [
"facebook/opt-125m",
"gpt2",

@@ -38,26 +47,37 @@ TOKENIZERS = [
]


def _run_incremental_decode(tokenizer, all_input_ids,
skip_special_tokens: bool, starting_index: int):
decoded_text = ""
offset = 0
token_offset = 0
prev_tokens = None
for i in range(starting_index, len(all_input_ids)):
new_tokens, text, offset, token_offset = detokenize_incrementally(
tokenizer,
all_input_ids[:i + 1],
prev_tokens,
offset,
token_offset,
skip_special_tokens=skip_special_tokens)
decoded_text += text
if prev_tokens is None:
prev_tokens = new_tokens
else:
prev_tokens += new_tokens
return decoded_text
def _run_incremental_decode(tokenizer,
all_input_ids,
skip_special_tokens: bool,
starting_index: int,
spaces_between_special_tokens: bool = True,
fast: Optional[bool] = None):

prompt_token_ids = all_input_ids[:starting_index]

params = SamplingParams(
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
request = EngineCoreRequest("", "", prompt_token_ids, None, None, None,
params, None, 0.0, None)

if fast is None:
detokenizer = IncrementalDetokenizer.from_new_request(
tokenizer, request)
elif fast:
detokenizer = FastIncrementalDetokenizer(tokenizer, request)
else:
detokenizer = SlowIncrementalDetokenizer(tokenizer, request)

output_text = ""
for i, token_id in enumerate(all_input_ids[starting_index:]):
detokenizer.update([token_id], False)
finished = i == len(all_input_ids) - 1
output_text += detokenizer.get_next_output_text(finished, delta=True)

return output_text, detokenizer.output_token_ids


@pytest.fixture

@@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth):
starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids

decoded_text = _run_incremental_decode(tokenizer,
all_input_ids,
skip_special_tokens=True,
starting_index=starting_index)
decoded_text, out_ids = _run_incremental_decode(
tokenizer,
all_input_ids,
skip_special_tokens=True,
starting_index=starting_index)
assert decoded_text == truth
assert out_ids == all_input_ids[starting_index:]


@pytest.fixture
@@ -106,40 +128,86 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
@pytest.mark.parametrize("fast", (True, False))
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
spaces_between_special_tokens, fast):
if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
pytest.skip()

if skip_special_tokens and not spaces_between_special_tokens:
pytest.skip()

if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
# Fix up inconsistency in fast/slow tokenizer behaviour.
tokenizer.add_special_tokens({
"additional_special_tokens": [
at for at in
tokenizer._tokenizer.get_added_tokens_decoder().values()
if at.special
]
})

extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \
else {"spaces_between_special_tokens": spaces_between_special_tokens}

truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
if tokenizer.bos_token_id is not None:
truth_tokens.insert(0, tokenizer.bos_token_id)
truth_tokens.append(tokenizer.eos_token_id)

new_truth = tokenizer.decode(truth_tokens,
skip_special_tokens=skip_special_tokens,
**extra_decode_args)

if with_prompt:
truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
prompt_input_ids = truth_tokens[:len(truth) // 2]
generated_input_ids = truth_tokens[len(truth) // 2:]
num_prompt_tokens = len(
tokenizer(truth[:len(truth) // 2],
add_special_tokens=False).input_ids)
if tokenizer.bos_token_id is not None:
num_prompt_tokens += 1

prompt_input_ids = truth_tokens[:num_prompt_tokens]
generated_input_ids = truth_tokens[num_prompt_tokens:]
all_input_ids = prompt_input_ids + generated_input_ids
starting_index = len(prompt_input_ids)
prompt = tokenizer.decode(prompt_input_ids,
skip_special_tokens=skip_special_tokens)
generated = truth[len(prompt):]
else:
generated = truth
starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
if skip_special_tokens:
if tokenizer.bos_token_id is not None:
all_input_ids = [tokenizer.bos_token_id] + all_input_ids
starting_index += 1
all_input_ids = all_input_ids + [tokenizer.eos_token_id]
skip_special_tokens=skip_special_tokens,
**extra_decode_args)

decoded_text = _run_incremental_decode(
generated = new_truth[len(prompt):]
else:
generated = new_truth
starting_index = 0
all_input_ids = truth_tokens

decoded_text, out_ids = _run_incremental_decode(
tokenizer,
all_input_ids,
skip_special_tokens=skip_special_tokens,
starting_index=starting_index)
starting_index=starting_index,
spaces_between_special_tokens=spaces_between_special_tokens,
fast=fast)

assert decoded_text == generated
assert out_ids == all_input_ids[starting_index:]

decoded_text = _run_incremental_decode(

@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("fast", (True, False))
def test_oov_decode(tokenizer, fast):
if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
pytest.skip()

decoded_text, out_ids = _run_incremental_decode(
tokenizer, [len(tokenizer)],
skip_special_tokens=skip_special_tokens,
starting_index=starting_index)
skip_special_tokens=True,
starting_index=0,
spaces_between_special_tokens=True,
fast=fast)

assert decoded_text == ''
assert out_ids == [len(tokenizer)]


@pytest.fixture

@@ -165,15 +233,14 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
tokenizer) -> list[int]:
complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
return complete_sequence_token_ids
return tokenizer(complete_sequence, add_special_tokens=False).input_ids


def create_sequence(prompt_token_ids=None):
prompt_token_ids = prompt_token_ids or [1]
prompt_token_ids = prompt_token_ids or []
return Sequence(
seq_id=0,
inputs=token_inputs(prompt_token_ids, prompt="<s>"),
inputs=token_inputs(prompt_token_ids),
block_size=16,
)

@@ -224,7 +291,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
assert sequential_result != "".join(sequential_logprobs_text_other_token)

if skip_special_tokens:
if not skip_special_tokens:
# Text for logprobs for the chosen token should be the same as the
# generated text. Note that this will only be true if we skip
# special tokens.

@@ -233,10 +300,23 @@

@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
def test_decode_prompt_logprobs(complete_sequence: str,
complete_sequence_token_ids: list[int],
detokenizer: Detokenizer):

# We want to use skip_special_tokens=False here but Mistral tokenizers
# don't support that.
if complete_sequence not in SPECIAL_TOKS_TRUTH:
skip_special_tokens = True
elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None),
MistralTokenizer):
skip_special_tokens = False
else:
pytest.skip("MistralTokenizers don't support "
"skip_special_tokens=False")
return
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=True,
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
prompt_logprobs=1)

# Run sequentially.

@@ -256,8 +336,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
# decoded_prompt_logprobs doesn't contain the first token.
token_ids = complete_sequence_token_ids
tokenizer = detokenizer.get_tokenizer_for_seq(seq)
text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
text_full = tokenizer.decode(token_ids,
skip_special_tokens=skip_special_tokens)
text_first = tokenizer.decode(token_ids[0],
skip_special_tokens=skip_special_tokens)
text = text_full[len(text_first):]

# Text for logprobs for the chosen token should be the same as the
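The core property these updated tests pin down is that incremental decoding yields exactly the same text and token ids as a one-shot decode, for both the fast and slow paths. A self-contained sketch of that invariant using only the Hugging Face libraries (the checkpoint and sample string are illustrative; `_tokenizer` is the fast tokenizer's underlying tokenizers.Tokenizer, as used in the diff above):

from transformers import AutoTokenizer
from tokenizers.decoders import DecodeStream

hf_tok = AutoTokenizer.from_pretrained("gpt2")  # any fast tokenizer
ids = hf_tok("Hello here, this is a simple test",
             add_special_tokens=False).input_ids

stream = DecodeStream(skip_special_tokens=True)
streamed = ""
for tid in ids:
    piece = stream.step(hf_tok._tokenizer, tid)
    if piece is not None:
        streamed += piece

# Streaming decode must match the one-shot decode of the same ids.
assert streamed == hf_tok.decode(ids, skip_special_tokens=True)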
@@ -70,6 +70,15 @@ class MistralToolParser(ToolParser):
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!")

def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
if request.tools and request.tool_choice != 'none':
# do not skip special tokens because mistral uses the special
# tokens to indicate the start and end of the tool calls
# information.
request.skip_special_tokens = False
return request

def extract_tool_calls(
self,
model_output: str,
@@ -1,8 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from typing import Optional

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from transformers import PreTrainedTokenizerFast

from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer_utils import (

@@ -12,39 +15,22 @@ from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__)


@dataclass
class IncrementalDetokenizer:

# Generation data
token_ids: list[int]
output_text: str = ""
tokens: list[str] = field(default_factory=list)
prompt_len: int = 0

# Stop strings
stop: list[str] = field(default_factory=list)
include_stop_str_in_output: bool = False

# Metadata for incremental detokenization
prefix_offset: int = 0
read_offset: int = 0

# Parameters for detokenization
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True

# Tokenizer for this request,
# None if detokenization is disabled.
tokenizer: Optional[AnyTokenizer] = None

# Accounting for stop string buffering
stop_buffer_length: int = 0
_last_output_text_offset: int = 0
def __init__(self):
self.token_ids: list[int] = []

@property
def output_token_ids(self) -> list[int]:
return self.token_ids if not self.prompt_len else (
self.token_ids[self.prompt_len:])
return self.token_ids

def update(self, new_token_ids: list[int],
stop_terminated: bool) -> Optional[str]:
self.token_ids.extend(new_token_ids)
return None

def get_next_output_text(self, finished: bool, delta: bool) -> str:
return ""

@classmethod
def from_new_request(

@@ -54,39 +40,37 @@ class IncrementalDetokenizer:
) -> "IncrementalDetokenizer":

if tokenizer is None:
return cls(token_ids=[])
# No tokenizer => skipping detokenization.
return IncrementalDetokenizer()

tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids,
skip_special_tokens=request.sampling_params.skip_special_tokens,
)
if isinstance(tokenizer, PreTrainedTokenizerFast):
# Fast tokenizer => use tokenizers library DecodeStream.
return FastIncrementalDetokenizer(tokenizer, request)

# Fall back to slow python-based incremental detokenization.
return SlowIncrementalDetokenizer(tokenizer, request)


class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):

def __init__(self, request: EngineCoreRequest):
super().__init__()

# Stop strings
params = request.sampling_params
self.stop = stop = params.stop
self.include_stop_str_in_output = params.include_stop_str_in_output

stops = request.sampling_params.stop
# Number of chars to hold back when stop strings are to be excluded
# from streamed output.
if stops and not request.sampling_params.include_stop_str_in_output:
stop_buffer_length = max(len(s) for s in stops) - 1
if stop and not self.include_stop_str_in_output:
self.stop_buffer_length = max(len(s) for s in stop) - 1
else:
stop_buffer_length = 0
self.stop_buffer_length = 0
self._last_output_text_offset: int = 0

return cls(
tokens=tokens,
# Detokenizer mutates this list, so need a unique copy.
# NOTE(Nick): could we take ownership of it though?
token_ids=request.prompt_token_ids.copy(),
stop=stops,
include_stop_str_in_output=request.sampling_params.
include_stop_str_in_output,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=request.sampling_params.skip_special_tokens,
spaces_between_special_tokens=request.sampling_params.
spaces_between_special_tokens,
prompt_len=len(request.prompt_token_ids),
tokenizer=tokenizer,
stop_buffer_length=stop_buffer_length,
)
# Generation data
self.output_text = ""

def update(self, new_token_ids: list[int],
stop_terminated: bool) -> Optional[str]:

@@ -98,11 +82,7 @@ class IncrementalDetokenizer:
Return matched stop string or None.
"""
if not new_token_ids:
# Skip detokenization if no new token ids
return None
if self.tokenizer is None:
# Skip detokenization if no tokenizer
self.token_ids.extend(new_token_ids)
# Skip detokenization if no new token ids.
return None

if stop_terminated and not self.include_stop_str_in_output:

@@ -116,34 +96,16 @@ class IncrementalDetokenizer:
# 1) Detokenize the new token ids incrementally.
# TODO(woosuk): This method becomes very inefficient when the number of
# new_token_ids is more than 1. We need to optimize this.
decoded_text = ""
offset_before = len(self.output_text)
for new_token_id in new_token_ids:
self.token_ids.append(new_token_id)
(new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally(
tokenizer=self.tokenizer,
all_input_ids=self.token_ids,
prev_tokens=self.tokens,
prefix_offset=self.prefix_offset,
read_offset=self.read_offset,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.
spaces_between_special_tokens,
)

self.tokens.extend(new_tokens)
self.prefix_offset = prefix_offset
self.read_offset = read_offset

decoded_text += new_decoded_token_text

self.output_text += decoded_text
self.output_text += self.decode_next(new_token_id)

if stop_terminated:
if skipped_stop_token_id is not None:
# Cleanup after skipping detokenization
# Cleanup after skipping detokenization.
self.token_ids.append(skipped_stop_token_id)
# Stop token triggered; skip stop string check
# Stop token triggered; skip stop string check.
return None

# 2) Evaluate stop strings.

@@ -151,7 +113,7 @@ class IncrementalDetokenizer:
if self.stop:
stop = StopChecker.check_stop_strings(
output_text=self.output_text,
new_char_count=len(decoded_text),
new_char_count=len(self.output_text) - offset_before,
stop=self.stop,
include_in_output=self.include_stop_str_in_output,
)

@@ -162,6 +124,10 @@ class IncrementalDetokenizer:

return stop_string

@abstractmethod
def decode_next(self, next_token_id: int) -> str:
raise NotImplementedError

def get_next_output_text(self, finished: bool, delta: bool) -> str:
"""If delta is True, only new text since the last call to
this method is returned"""

@@ -177,3 +143,114 @@ class IncrementalDetokenizer:
self._last_output_text_offset = length
return self.output_text[last_offset:length]
return ""


class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):

def __init__(self, tokenizer: PreTrainedTokenizerFast,
request: EngineCoreRequest):
super().__init__(request)

sampling_params = request.sampling_params
self.stream = DecodeStream(
skip_special_tokens=sampling_params.skip_special_tokens)

self.tokenizer: Tokenizer = tokenizer._tokenizer

# Find a safe place to start.
prompt_suffix = request.prompt_token_ids
prompt_len = len(prompt_suffix)
if prompt_len > 4:
for i in range(4, max(prompt_len + 1, 32)):
suffix = request.prompt_token_ids[-i:]
if '�' not in self.tokenizer.decode(suffix):
prompt_suffix = suffix
break

# Prime the stream.
for tid in prompt_suffix:
self.stream.step(self.tokenizer, tid)

self.spaces_between_special_tokens = (
sampling_params.skip_special_tokens
or sampling_params.spaces_between_special_tokens)

if not self.spaces_between_special_tokens:
# Store dict of added token ids so that we can suppress
# the spaces between them.
if (added_token_ids := getattr(self.tokenizer, "added_token_ids",
None)) is None:
self.tokenizer.added_token_ids = added_token_ids = {
tid: tok.content
for tid, tok in
self.tokenizer.get_added_tokens_decoder().items()
}

if added_token_ids:
self.last_special = False
self.added_token_ids = added_token_ids
else:
# No added tokens.
self.spaces_between_special_tokens = True

def decode_next(self, next_token_id: int) -> str:
token = self.stream.step(self.tokenizer, next_token_id)

if not self.spaces_between_special_tokens:
special_token = self.added_token_ids.get(next_token_id)
is_special = special_token is not None
if is_special and self.last_special:
# Return raw token string without any prefixed spaces.
token = special_token
self.last_special = is_special

return token or ""


class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):

def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
super().__init__(request)

self.tokenizer = tokenizer

# Metadata for incremental detokenization.
self.tokens, self.prefix_offset, self.read_offset = (
convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids,
skip_special_tokens=request.sampling_params.
skip_special_tokens,
))

self.token_ids.extend(request.prompt_token_ids)
self.prompt_len = len(request.prompt_token_ids)

params = request.sampling_params
self.skip_special_tokens = params.skip_special_tokens
self.spaces_between_special_tokens = (
params.spaces_between_special_tokens)

@property
def output_token_ids(self) -> list[int]:
return self.token_ids if not self.prompt_len else (
self.token_ids[self.prompt_len:])

def decode_next(self, next_token_id: int) -> str:
new_tokens, decoded_text, prefix_offset, read_offset = (
detokenize_incrementally(
tokenizer=self.tokenizer,
all_input_ids=self.token_ids,
prev_tokens=self.tokens,
prefix_offset=self.prefix_offset,
read_offset=self.read_offset,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.
spaces_between_special_tokens,
))

self.tokens.extend(new_tokens)
self.prefix_offset = prefix_offset
self.read_offset = read_offset

return decoded_text
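A note on the FastIncrementalDetokenizer constructor above: DecodeStream starts with no context, so the detokenizer first looks for a short suffix of the prompt whose standalone decode contains no U+FFFD replacement character (i.e. it does not begin mid-way through a multi-byte character) and steps the stream through it, so that the first generated tokens detokenize with the right preceding context. A rough standalone sketch of that priming idea, outside of vLLM (the helper name is made up for illustration):

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

def primed_stream(tok: Tokenizer, prompt_ids: list[int],
                  skip_special_tokens: bool = True) -> DecodeStream:
    """Hypothetical helper: prime a DecodeStream with a 'safe' prompt suffix."""
    stream = DecodeStream(skip_special_tokens=skip_special_tokens)
    suffix = prompt_ids
    if len(prompt_ids) > 4:
        # Grow the suffix until it decodes cleanly (no U+FFFD), so the
        # stream starts on a character boundary rather than inside one.
        for i in range(4, len(prompt_ids) + 1):
            candidate = prompt_ids[-i:]
            if '\ufffd' not in tok.decode(candidate):
                suffix = candidate
                break
    for tid in suffix:
        stream.step(tok, tid)
    return stream

Generated token ids are then fed through the same stream one at a time, which is what decode_next() does in the diff above.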