[Fix][Structured Output] using vocab_size to construct matcher (#14868)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent aaaec52ad9
commit c0efdd655b
@@ -200,6 +200,7 @@ steps:
   - pytest -v -s v1/core
   - pytest -v -s v1/entrypoints
   - pytest -v -s v1/engine
   - pytest -v -s v1/entrypoints
   - pytest -v -s v1/sample
   - pytest -v -s v1/worker
+  - pytest -v -s v1/structured_output
@@ -20,7 +20,7 @@ tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.11, < 0.11
 outlines == 0.1.11
 lark == 1.2.2
-xgrammar == 0.1.15; platform_machine == "x86_64" or platform_machine == "aarch64"
+xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"
 typing_extensions >= 4.10
 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
 partial-json-parser # used for parsing partial JSON outputs
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

+import json
 import pickle

 import pytest
@@ -208,8 +209,6 @@ def test_guided_decoding_backend_options():


 def test_pickle_xgrammar_tokenizer_data():

     # TODO: move to another test file for xgrammar
     try:
         import xgrammar as xgr
     except ImportError:
@@ -217,7 +216,11 @@ def test_pickle_xgrammar_tokenizer_data():

     from vllm.model_executor.guided_decoding.xgrammar_decoding import (
         TokenizerData)
-    tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW)
+    tokenizer_data = TokenizerData(
+        metadata=
+        '{"vocab_type":2,"vocab_size":151665,"add_prefix_space":false,"stop_token_ids":[151645]}',
+        encoded_vocab=['!', '"', '#', '$', '%'],
+    )
     pickled = pickle.dumps(tokenizer_data)

     assert pickled is not None
@@ -225,4 +228,5 @@ def test_pickle_xgrammar_tokenizer_data():
     depickled: TokenizerData = pickle.loads(pickled)

     assert depickled is not None
-    assert depickled.vocab_type == xgr.VocabType.RAW
+    assert json.loads(
+        depickled.metadata)['vocab_type'] == xgr.VocabType.BYTE_LEVEL.value
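Note on the updated test: TokenizerData now carries the xgrammar tokenizer properties as a single JSON metadata string plus the encoded vocabulary, so the cached object pickles without dragging the tokenizer along, and the vocab type survives the round trip as a plain integer (2 == BYTE_LEVEL). A minimal stand-in sketch of that round trip (not part of the diff; _TokenizerData is a hypothetical stand-in for vLLM's TokenizerData):

    import json
    import pickle
    from dataclasses import dataclass, field

    @dataclass(frozen=True)
    class _TokenizerData:
        # Hypothetical stand-in mirroring the new shape: plain str + list[str].
        metadata: str
        encoded_vocab: list[str] = field(default_factory=list)

    data = _TokenizerData(
        metadata='{"vocab_type":2,"vocab_size":151665,'
                 '"add_prefix_space":false,"stop_token_ids":[151645]}',
        encoded_vocab=['!', '"', '#', '$', '%'],
    )

    # Only plain strings and lists are pickled; no tokenizer object travels with it.
    restored = pickle.loads(pickle.dumps(data))
    assert json.loads(restored.metadata)["vocab_type"] == 2  # BYTE_LEVEL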
@@ -18,9 +18,6 @@ MODELS_TO_TEST = [
     "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
 ]

-# Undo after https://github.com/vllm-project/vllm/pull/14868
-pytest.skip(allow_module_level=True)
-
 
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
@@ -9,7 +9,6 @@ from vllm.model_executor.guided_decoding.reasoner import get_reasoner
 from vllm.model_executor.guided_decoding.utils import (
     convert_lark_to_gbnf, grammar_is_likely_lark,
     has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
-from vllm.platforms import CpuArchEnum

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -26,7 +25,7 @@ def maybe_backend_fallback(

     def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
                           fallback: str) -> None:
-        """Change the backend to the specified fallback with a warning log,
+        """Change the backend to the specified fallback with a warning log,
         or raise a ValueError if the `no-fallback` option is specified."""
         if guided_params.no_fallback():
             raise ValueError(message)
@@ -53,19 +52,12 @@ def maybe_backend_fallback(
     if guided_params.backend_name == "xgrammar":
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (
             xgr_installed)
-        # xgrammar only has x86 wheels for linux, fallback to outlines
-        from vllm.platforms import current_platform
-        if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
-            fallback_or_error(guided_params,
-                              "xgrammar is only supported on x86 CPUs.",
-                              "outlines")
-
         # xgrammar doesn't support regex, fallback to outlines
         if guided_params.regex is not None:
             fallback_or_error(
                 guided_params,
                 "xgrammar does not support regex guided decoding.", "outlines")

         # xgrammar doesn't support some JSON schema features
         elif (guided_params.json is not None
               and has_xgrammar_unsupported_json_features(guided_params.json)):
@@ -9,13 +9,11 @@ from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, List

 import torch
-from transformers import PreTrainedTokenizerFast

 from vllm.logger import init_logger

 try:
     import xgrammar as xgr
-    from xgrammar.base import _core as xgr_core
     xgr_installed = True
 except ImportError:
     xgr_installed = False
@@ -35,7 +33,6 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)


 # TODO: passing batch size to max threads here
 def get_local_xgrammar_guided_decoding_logits_processor(
         guided_params: GuidedDecodingParams,
         tokenizer: PreTrainedTokenizer,
@@ -52,18 +49,8 @@ def get_local_xgrammar_guided_decoding_logits_processor(
 @dataclass(frozen=True)
 class TokenizerData:
     """Immutable container for cached tokenizer data."""
+    metadata: str
     encoded_vocab: list[str] = field(default_factory=list)
-    stop_token_ids: list[int] | None = None
-    # These fields are mutually exclusive: `backend_str` is used to create a
-    # TokenizeInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is
-    # used within the constructor of TokenizeInfo
-    backend_str: str | None = None
-    vocab_type: xgr.VocabType | None = None
-
-    def __post_init__(self):
-        # Check for mutual exclusive
-        assert not (self.backend_str and self.vocab_type), \
-            "backend_str and vocab_type are mutual exclusive"


 class TokenizerDataCache:
@@ -71,46 +58,52 @@ class TokenizerDataCache:
     _cache: dict[int, TokenizerData] = {}

     @classmethod
-    def get_tokenizer_data(cls,
-                           tokenizer: PreTrainedTokenizer) -> TokenizerData:
-        tokenizer_hash = hash(tokenizer)
+    def get_tokenizer_data(
+        cls,
+        tokenizer: PreTrainedTokenizer,
+        /,
+        *,
+        tokenizer_hash: int,
+        vocab_size: int,
+    ) -> TokenizerData:

         if tokenizer_hash not in cls._cache:
             # Vendored from xgrammar logic since we cannot pickle the tokenizer
             # https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # noqa: E501
+            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
+                tokenizer,
+                # NOTE: We will need to use lm_head's vocab_size
+                # to determine correct special_token_ids for this tokenizer.
+                # See https://github.com/mlc-ai/xgrammar/commit/70c959fb6d9cea75aae33c414763cd0602022d92 # noqa: E501
+                vocab_size=vocab_size,
+            )
+            metadata = json.loads(tokenizer_info.dump_metadata())

+            # Vendored from xgrammar logic to get encoded_vocab
+            # https://github.com/mlc-ai/xgrammar/blob/989222175c2a30fb7987d8bcce35bec1bf6817f2/python/xgrammar/tokenizer_info.py#L127 # noqa: E501
             try:
-                encoded_vocab = [
-                    token for token, _ in sorted(tokenizer.get_vocab().items(),
-                                                 key=lambda x: x[1])
-                ]
+                vocab_dict = tokenizer.get_vocab()
             except AttributeError as e:
                 raise ValueError(
                     f"Cannot get the vocabulary of the tokenizer "
                     f"{type(tokenizer)}. The tokenizer should have a "
                     "get_vocab method.") from e

-            stop_token_ids = None
-            backend_str = ""
-            vocab_type = xgr.VocabType.RAW
+            # maintain tokenizer's indexing
+            encoded_vocab = [""] * tokenizer_info.vocab_size
+            for token, idx in vocab_dict.items():
+                if idx < tokenizer_info.vocab_size:
+                    encoded_vocab[idx] = token

-            if stop_token_ids is None and hasattr(
-                    tokenizer,
-                    "eos_token_id") and tokenizer.eos_token_id is not None:
-                stop_token_ids = [tokenizer.eos_token_id]
-
-            if isinstance(tokenizer, PreTrainedTokenizerFast):
-                backend_str = tokenizer.backend_tokenizer.to_str()
-                vocab_type = None
-
-            elif isinstance(tokenizer, MistralTokenizer):
+            if isinstance(tokenizer, MistralTokenizer):
                 # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
-                vocab_type = xgr.VocabType.BYTE_FALLBACK
+                metadata.update({
+                    "vocab_type": xgr.VocabType.BYTE_FALLBACK,
+                    "add_prefix_space": True
+                })

             cls._cache[tokenizer_hash] = TokenizerData(
                 encoded_vocab=encoded_vocab,
-                stop_token_ids=stop_token_ids,
-                backend_str=backend_str,
-                vocab_type=vocab_type)
+                metadata=json.dumps(metadata),
+            )

         return cls._cache[tokenizer_hash]
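For context on the hunk above: the cache now asks xgrammar itself to derive the tokenizer properties, passing the model's lm_head vocab size, and stores the result as the JSON string returned by dump_metadata(). A hedged sketch of that path, assuming xgrammar >= 0.1.16 is installed and using a placeholder model id:

    import json

    import xgrammar as xgr
    from transformers import AutoConfig, AutoTokenizer

    model_id = "Qwen/Qwen2.5-1.5B-Instruct"  # placeholder model
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Use the model's lm_head vocab size (from the HF config), which can be larger
    # than len(tokenizer.get_vocab()) when the embedding/lm_head is padded.
    lm_head_vocab_size = AutoConfig.from_pretrained(model_id).vocab_size
    tokenizer_info = xgr.TokenizerInfo.from_huggingface(
        tokenizer, vocab_size=lm_head_vocab_size)

    metadata = tokenizer_info.dump_metadata()  # plain JSON string, safe to cache/pickle
    print(json.loads(metadata)["vocab_size"])  # reflects lm_head_vocab_size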
@@ -129,30 +122,15 @@ class GrammarCompilerCache:
         cache_key = str(config.tokenizer_hash)

         if cache_key not in cls._cache:
             assert config.tokenizer_data is not None
             assert config.tokenizer_data.encoded_vocab is not None

             config_data = config.tokenizer_data

-            # In TokenizerDataCache.get_tokenizer_data, a serializable
-            # tokenizer_data is created and cached. This data is used to build
-            # a tokenizer_info and create an xgrammar compiler.
-            # - If tokenizer_data has backend_str set, use
-            #   xgr_core.TokenizerInfo.from_huggingface (a C++ bind).
-            # - Otherwise, use the default constructor with vocab_type.
-            # - xgr_core.TokenizerInfo.from_huggingface !=
-            #   xgr.TokenizerInfo.from_huggingface.
-            if config_data.backend_str:
-                tokenizer_info = xgr.TokenizerInfo._create_from_handle(
-                    xgr_core.TokenizerInfo.from_huggingface(
-                        config_data.encoded_vocab, config_data.backend_str,
-                        config.vocab_size, config_data.stop_token_ids))
-            else:
-                tokenizer_info = xgr.TokenizerInfo(
-                    config_data.encoded_vocab,
-                    config_data.vocab_type,
-                    vocab_size=config.vocab_size,
-                    stop_token_ids=config_data.stop_token_ids)
+            tokenizer_info = xgr.TokenizerInfo.from_vocab_and_metadata(
+                encoded_vocab=config_data.encoded_vocab,
+                metadata=config_data.metadata,
+            )
             cls._cache[cache_key] = xgr.GrammarCompiler(
                 tokenizer_info, max_threads=config.max_threads)
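The rebuild side shown above is now a single call: the compiler cache reconstructs a TokenizerInfo purely from the cached vocabulary and metadata string, with no backend_str/vocab_type branching and no C++ handle. A sketch under the same assumptions (xgrammar >= 0.1.16); the compile_builtin_json_grammar call at the end is illustrative of how the resulting compiler is typically used:

    import xgrammar as xgr


    def build_compiler(encoded_vocab: list[str], metadata: str,
                       max_threads: int = 8) -> xgr.GrammarCompiler:
        # Everything xgrammar needs travels in the metadata JSON, so no tokenizer
        # object is required at this point.
        tokenizer_info = xgr.TokenizerInfo.from_vocab_and_metadata(
            encoded_vocab=encoded_vocab,
            metadata=metadata,
        )
        return xgr.GrammarCompiler(tokenizer_info, max_threads=max_threads)


    # Usage with a cached TokenizerData (hypothetical variable `tokenizer_data`):
    # compiler = build_compiler(tokenizer_data.encoded_vocab, tokenizer_data.metadata)
    # ctx = compiler.compile_builtin_json_grammar()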
@@ -163,13 +141,12 @@
 class GrammarConfig:
     """Serializable configuration for grammar compilation"""
     tokenizer_hash: int
-    vocab_size: int
+    tokenizer_data: TokenizerData
     json_str: str | None = None
     grammar_str: str | None = None
     json_object: bool | None = None
     any_whitespace: bool = True
     max_threads: int = 8
-    tokenizer_data: TokenizerData | None = None

     @classmethod
     def from_guided_params(cls,
@@ -179,7 +156,11 @@
                            max_threads: int = 8) -> GrammarConfig:

         tokenizer_hash = hash(tokenizer)
-        tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
+        tokenizer_data = TokenizerDataCache.get_tokenizer_data(
+            tokenizer,
+            tokenizer_hash=tokenizer_hash,
+            vocab_size=model_config.hf_text_config.vocab_size,
+        )

         if guided_params.json:
             if not isinstance(guided_params.json, str):
@@ -218,7 +199,6 @@
                 raise ValueError(str(err)) from err

             return cls(json_str=json_str,
-                       vocab_size=model_config.hf_text_config.vocab_size,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
                        tokenizer_data=tokenizer_data,
@@ -246,14 +226,12 @@
                 raise ValueError(str(err)) from err

             return cls(grammar_str=grammar_str,
-                       vocab_size=model_config.hf_text_config.vocab_size,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
                        tokenizer_data=tokenizer_data)
         elif guided_params.json_object:
             return cls(
                 json_object=True,
-                vocab_size=model_config.hf_text_config.vocab_size,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
                 tokenizer_data=tokenizer_data,
@@ -267,7 +245,6 @@

             return cls(
                 grammar_str=choice_str,
-                vocab_size=model_config.hf_text_config.vocab_size,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
                 tokenizer_data=tokenizer_data,
@@ -291,6 +268,13 @@
         grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
         return grammar

+    @staticmethod
+    def tokenizer_info(tokenizer_data: TokenizerData) -> xgr.TokenizerInfo:
+        return xgr.TokenizerInfo.from_vocab_and_metadata(
+            encoded_vocab=tokenizer_data.encoded_vocab,
+            metadata=tokenizer_data.metadata,
+        )
+

 @dataclass
 class XGrammarLogitsProcessor:
@@ -299,11 +283,16 @@ class XGrammarLogitsProcessor:
     reasoner: Reasoner | None = None

     ctx: xgr.CompiledGrammar | None = None
+    tokenizer_info: xgr.TokenizerInfo = None  # type: ignore[assignment]
     token_bitmask: torch.Tensor = None  # type: ignore[assignment]
     matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
     batch_size: int = field(default=1)
     prefilled: bool = field(default=False)

+    def __post_init__(self):
+        self.tokenizer_info = self.config.tokenizer_info(
+            self.config.tokenizer_data)
+
     def __getstate__(self) -> dict[str, Any]:
         return {'config': self.config, 'reasoner': self.reasoner}
@@ -311,6 +300,8 @@ class XGrammarLogitsProcessor:
         self.config = state['config']
         self.reasoner = state['reasoner']

+        self.tokenizer_info = GrammarConfig.tokenizer_info(
+            self.config.tokenizer_data)
         self.ctx = None
         self.matchers = []
         self.batch_size = 1
@@ -352,7 +343,7 @@ class XGrammarLogitsProcessor:
                 xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
             ]
             self.token_bitmask = xgr.allocate_token_bitmask(
-                self.batch_size, self.config.vocab_size)
+                self.batch_size, self.tokenizer_info.vocab_size)

             if not self.prefilled:
                 # Have not sampled a token yet
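This is the change the commit title refers to: the token bitmask (and thus the matcher masking step) is now sized from the rebuilt TokenizerInfo, which was constructed with the lm_head vocab size, so matcher, bitmask, and logits all agree in width. A small hedged sketch of that sizing, assuming a compiled grammar `ctx` and a `tokenizer_info` built as in the earlier hunks:

    import torch
    import xgrammar as xgr


    def make_matchers(
        ctx: xgr.CompiledGrammar,
        tokenizer_info: xgr.TokenizerInfo,
        batch_size: int = 1,
    ) -> tuple[list[xgr.GrammarMatcher], torch.Tensor]:
        """One matcher per sequence plus a bitmask sized to the model's logits row."""
        matchers = [xgr.GrammarMatcher(ctx) for _ in range(batch_size)]
        # The width must match the logits, i.e. the (possibly padded) lm_head vocab
        # size carried by tokenizer_info, not len(tokenizer.get_vocab()).
        bitmask = xgr.allocate_token_bitmask(batch_size, tokenizer_info.vocab_size)
        return matchers, bitmask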
@@ -40,7 +40,7 @@ class StructuredOutputManager:
         tokenizer_group.ping()

         tokenizer = tokenizer_group.get_lora_tokenizer(None)
-        self.vocab_size = len(tokenizer.get_vocab())
+        self.vocab_size = self.vllm_config.model_config.get_vocab_size()
         if isinstance(tokenizer, MistralTokenizer):
             # NOTE: ideally, xgrammar should handle this accordingly.
             # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
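The v1 manager hunk above makes the same point at the engine level: the vocab size should come from the model config (read via model_config.get_vocab_size() in the diff) rather than from the tokenizer's vocab dict, because the two can differ when the embedding/lm_head is padded. An illustrative comparison using plain transformers, with a placeholder model id:

    from transformers import AutoConfig, AutoTokenizer

    model_id = "Qwen/Qwen2.5-1.5B-Instruct"  # placeholder model

    tokenizer_vocab = len(AutoTokenizer.from_pretrained(model_id).get_vocab())
    model_vocab = AutoConfig.from_pretrained(model_id).vocab_size

    # The model's logits have `model_vocab` entries, so structured-output bitmasks
    # must be sized with model_vocab even when tokenizer_vocab is smaller.
    print(tokenizer_vocab, model_vocab)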