[Fix][Structured Output] using vocab_size to construct matcher (#14868)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
parent
aaaec52ad9
commit
c0efdd655b
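
In short: the token bitmask xgrammar builds must match the width of the logits coming out of the model's `lm_head`, which can be larger than `len(tokenizer.get_vocab())` when the embedding matrix is padded. `TokenizerData` therefore now carries xgrammar's own serialized metadata, built with the model's `vocab_size`, instead of the hand-maintained `vocab_type`/`backend_str`/`stop_token_ids` fields. A minimal sketch of the round-trip the diff below relies on, using only the xgrammar calls that appear in it (the model name and vocab size are illustrative assumptions, not taken from this commit):

```python
import json

import xgrammar as xgr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")

# Build TokenizerInfo with the model's (possibly padded) vocab size rather
# than the tokenizer's, so special-token ids and bitmask width line up.
info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=151936)
metadata = info.dump_metadata()  # a plain JSON string, safe to pickle

# Rebuild the vocab in tokenizer index order, padding the tail with "".
encoded_vocab = [""] * info.vocab_size
for token, idx in tokenizer.get_vocab().items():
    if idx < info.vocab_size:
        encoded_vocab[idx] = token

# A receiving process reconstructs an equivalent TokenizerInfo from data.
rebuilt = xgr.TokenizerInfo.from_vocab_and_metadata(
    encoded_vocab=encoded_vocab, metadata=metadata)
assert rebuilt.vocab_size == json.loads(metadata)["vocab_size"]
```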
@@ -200,6 +200,7 @@ steps:
     - pytest -v -s v1/core
     - pytest -v -s v1/entrypoints
     - pytest -v -s v1/engine
+    - pytest -v -s v1/entrypoints
     - pytest -v -s v1/sample
     - pytest -v -s v1/worker
    - pytest -v -s v1/structured_output
@@ -20,7 +20,7 @@ tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.11, < 0.11
 outlines == 0.1.11
 lark == 1.2.2
-xgrammar == 0.1.15; platform_machine == "x86_64" or platform_machine == "aarch64"
+xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"
 typing_extensions >= 4.10
 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
 partial-json-parser # used for parsing partial JSON outputs
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

+import json
 import pickle

 import pytest
@@ -208,8 +209,6 @@ def test_guided_decoding_backend_options():


 def test_pickle_xgrammar_tokenizer_data():
-
-    # TODO: move to another test file for xgrammar
     try:
         import xgrammar as xgr
     except ImportError:
@@ -217,7 +216,11 @@ def test_pickle_xgrammar_tokenizer_data():

     from vllm.model_executor.guided_decoding.xgrammar_decoding import (
         TokenizerData)
-    tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW)
+    tokenizer_data = TokenizerData(
+        metadata=
+        '{"vocab_type":2,"vocab_size":151665,"add_prefix_space":false,"stop_token_ids":[151645]}',
+        encoded_vocab=['!', '"', '#', '$', '%'],
+    )
     pickled = pickle.dumps(tokenizer_data)

     assert pickled is not None
@@ -225,4 +228,5 @@ def test_pickle_xgrammar_tokenizer_data():
     depickled: TokenizerData = pickle.loads(pickled)

     assert depickled is not None
-    assert depickled.vocab_type == xgr.VocabType.RAW
+    assert json.loads(
+        depickled.metadata)['vocab_type'] == xgr.VocabType.BYTE_LEVEL.value
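
The rewritten assertion reflects that `vocab_type` now travels inside the metadata JSON as the enum's integer value rather than as a dataclass field; for the Qwen-style fixture above, `xgr.VocabType.BYTE_LEVEL.value == 2`. A self-contained restatement of what the test checks:

```python
import json

import xgrammar as xgr

metadata = ('{"vocab_type":2,"vocab_size":151665,'
            '"add_prefix_space":false,"stop_token_ids":[151645]}')
# The serialized enum value round-trips through JSON as a plain int.
assert json.loads(metadata)["vocab_type"] == xgr.VocabType.BYTE_LEVEL.value
```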
@@ -18,9 +18,6 @@ MODELS_TO_TEST = [
     "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
 ]

-# Undo after https://github.com/vllm-project/vllm/pull/14868
-pytest.skip(allow_module_level=True)
-

 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
@@ -9,7 +9,6 @@ from vllm.model_executor.guided_decoding.reasoner import get_reasoner
 from vllm.model_executor.guided_decoding.utils import (
     convert_lark_to_gbnf, grammar_is_likely_lark,
     has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
-from vllm.platforms import CpuArchEnum

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -53,19 +52,12 @@ def maybe_backend_fallback(
     if guided_params.backend_name == "xgrammar":
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (
             xgr_installed)
-        # xgrammar only has x86 wheels for linux, fallback to outlines
-        from vllm.platforms import current_platform
-        if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
-            fallback_or_error(guided_params,
-                              "xgrammar is only supported on x86 CPUs.",
-                              "outlines")
-
         # xgrammar doesn't support regex, fallback to outlines
         if guided_params.regex is not None:
             fallback_or_error(
                 guided_params,
                 "xgrammar does not support regex guided decoding.", "outlines")

         # xgrammar doesn't support some JSON schema features
         elif (guided_params.json is not None
               and has_xgrammar_unsupported_json_features(guided_params.json)):
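
The x86-only CPU check (and the `CpuArchEnum` import dropped in the previous hunk) is gone because availability is already decided at install and import time: the requirements marker limits xgrammar wheels to x86_64/aarch64, and `xgrammar_decoding` records whether the import succeeded. That guard, as it stands after this change:

```python
try:
    import xgrammar as xgr
    xgr_installed = True
except ImportError:
    # No wheel for this platform (or xgrammar not installed):
    # callers fall back to another backend based on this flag.
    xgr_installed = False
```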
@@ -9,13 +9,11 @@ from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, List

 import torch
-from transformers import PreTrainedTokenizerFast

 from vllm.logger import init_logger

 try:
     import xgrammar as xgr
-    from xgrammar.base import _core as xgr_core
     xgr_installed = True
 except ImportError:
     xgr_installed = False
@@ -35,7 +33,6 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)


-# TODO: passing batch size to max threads here
 def get_local_xgrammar_guided_decoding_logits_processor(
         guided_params: GuidedDecodingParams,
         tokenizer: PreTrainedTokenizer,
@@ -52,18 +49,8 @@ def get_local_xgrammar_guided_decoding_logits_processor(
 @dataclass(frozen=True)
 class TokenizerData:
     """Immutable container for cached tokenizer data."""
+    metadata: str
     encoded_vocab: list[str] = field(default_factory=list)
-    stop_token_ids: list[int] | None = None
-    # These fields are mutually exclusive: `backend_str` is used to create a
-    # TokenizeInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is
-    # used within the constructor of TokenizeInfo
-    backend_str: str | None = None
-    vocab_type: xgr.VocabType | None = None
-
-    def __post_init__(self):
-        # Check for mutual exclusive
-        assert not (self.backend_str and self.vocab_type), \
-            "backend_str and vocab_type are mutual exclusive"


 class TokenizerDataCache:
@@ -71,46 +58,52 @@ class TokenizerDataCache:
     _cache: dict[int, TokenizerData] = {}

     @classmethod
-    def get_tokenizer_data(cls,
-                           tokenizer: PreTrainedTokenizer) -> TokenizerData:
-        tokenizer_hash = hash(tokenizer)
+    def get_tokenizer_data(
+        cls,
+        tokenizer: PreTrainedTokenizer,
+        /,
+        *,
+        tokenizer_hash: int,
+        vocab_size: int,
+    ) -> TokenizerData:

         if tokenizer_hash not in cls._cache:
-            # Vendored from xgrammar logic since we cannot pickle the tokenizer
-            # https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # noqa: E501
+            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
+                tokenizer,
+                # NOTE: We will need to use lm_head's vocab_size
+                # to determine correct special_token_ids for this tokenizer.
+                # See https://github.com/mlc-ai/xgrammar/commit/70c959fb6d9cea75aae33c414763cd0602022d92 # noqa: E501
+                vocab_size=vocab_size,
+            )
+            metadata = json.loads(tokenizer_info.dump_metadata())
+
+            # Vendored from xgrammar logic to get encoded_vocab
+            # https://github.com/mlc-ai/xgrammar/blob/989222175c2a30fb7987d8bcce35bec1bf6817f2/python/xgrammar/tokenizer_info.py#L127 # noqa: E501
             try:
-                encoded_vocab = [
-                    token for token, _ in sorted(tokenizer.get_vocab().items(),
-                                                 key=lambda x: x[1])
-                ]
+                vocab_dict = tokenizer.get_vocab()
             except AttributeError as e:
                 raise ValueError(
                     f"Cannot get the vocabulary of the tokenizer "
                     f"{type(tokenizer)}. The tokenizer should have a "
                     "get_vocab method.") from e

-            stop_token_ids = None
-            backend_str = ""
-            vocab_type = xgr.VocabType.RAW
+            # maintain tokenizer's indexing
+            encoded_vocab = [""] * tokenizer_info.vocab_size
+            for token, idx in vocab_dict.items():
+                if idx < tokenizer_info.vocab_size:
+                    encoded_vocab[idx] = token

-            if stop_token_ids is None and hasattr(
-                    tokenizer,
-                    "eos_token_id") and tokenizer.eos_token_id is not None:
-                stop_token_ids = [tokenizer.eos_token_id]
-
-            if isinstance(tokenizer, PreTrainedTokenizerFast):
-                backend_str = tokenizer.backend_tokenizer.to_str()
-                vocab_type = None
-
-            elif isinstance(tokenizer, MistralTokenizer):
+            if isinstance(tokenizer, MistralTokenizer):
                 # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
-                vocab_type = xgr.VocabType.BYTE_FALLBACK
+                metadata.update({
+                    "vocab_type": xgr.VocabType.BYTE_FALLBACK,
+                    "add_prefix_space": True
+                })

             cls._cache[tokenizer_hash] = TokenizerData(
                 encoded_vocab=encoded_vocab,
-                stop_token_ids=stop_token_ids,
-                backend_str=backend_str,
-                vocab_type=vocab_type)
+                metadata=json.dumps(metadata),
+            )

         return cls._cache[tokenizer_hash]
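
Callers must now pass both identifiers explicitly; the positional-only `/` and keyword-only `*` markers make it hard to confuse the hash with the vocab size. Usage, as it appears in `from_guided_params` further down:

```python
tokenizer_data = TokenizerDataCache.get_tokenizer_data(
    tokenizer,
    tokenizer_hash=hash(tokenizer),
    vocab_size=model_config.hf_text_config.vocab_size,  # lm_head size
)
```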
@@ -129,30 +122,15 @@ class GrammarCompilerCache:
         cache_key = str(config.tokenizer_hash)

         if cache_key not in cls._cache:
-            assert config.tokenizer_data is not None
-            assert config.tokenizer_data.encoded_vocab is not None
-
             config_data = config.tokenizer_data

             # In TokenizerDataCache.get_tokenizer_data, a serializable
             # tokenizer_data is created and cached. This data is used to build
             # a tokenizer_info and create an xgrammar compiler.
-            # - If tokenizer_data has backend_str set, use
-            # xgr_core.TokenizerInfo.from_huggingface (a C++ bind).
-            # - Otherwise, use the default constructor with vocab_type.
-            # - xgr_core.TokenizerInfo.from_huggingface !=
-            # xgr.TokenizerInfo.from_huggingface.
-            if config_data.backend_str:
-                tokenizer_info = xgr.TokenizerInfo._create_from_handle(
-                    xgr_core.TokenizerInfo.from_huggingface(
-                        config_data.encoded_vocab, config_data.backend_str,
-                        config.vocab_size, config_data.stop_token_ids))
-            else:
-                tokenizer_info = xgr.TokenizerInfo(
-                    config_data.encoded_vocab,
-                    config_data.vocab_type,
-                    vocab_size=config.vocab_size,
-                    stop_token_ids=config_data.stop_token_ids)
+            tokenizer_info = xgr.TokenizerInfo.from_vocab_and_metadata(
+                encoded_vocab=config_data.encoded_vocab,
+                metadata=config_data.metadata,
+            )
             cls._cache[cache_key] = xgr.GrammarCompiler(
                 tokenizer_info, max_threads=config.max_threads)
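
Since both halves of `TokenizerData` are plain Python data, any process can rebuild an equivalent `TokenizerInfo` and compile grammars from it, with no C++ handle or backend-string branching. A sketch of the consumer side; `compile_builtin_json_grammar` is xgrammar's stock JSON grammar and stands in here for whatever grammar a request actually needs:

```python
# Rebuild the tokenizer description from serialized data...
tokenizer_info = xgr.TokenizerInfo.from_vocab_and_metadata(
    encoded_vocab=tokenizer_data.encoded_vocab,
    metadata=tokenizer_data.metadata,
)
# ...then compile grammars against it.
compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
compiled = compiler.compile_builtin_json_grammar()
```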
@@ -163,13 +141,12 @@ class GrammarCompilerCache:
 class GrammarConfig:
     """Serializable configuration for grammar compilation"""
     tokenizer_hash: int
-    vocab_size: int
+    tokenizer_data: TokenizerData
     json_str: str | None = None
     grammar_str: str | None = None
     json_object: bool | None = None
     any_whitespace: bool = True
     max_threads: int = 8
-    tokenizer_data: TokenizerData | None = None

     @classmethod
     def from_guided_params(cls,
@@ -179,7 +156,11 @@ class GrammarConfig:
                            max_threads: int = 8) -> GrammarConfig:

         tokenizer_hash = hash(tokenizer)
-        tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
+        tokenizer_data = TokenizerDataCache.get_tokenizer_data(
+            tokenizer,
+            tokenizer_hash=tokenizer_hash,
+            vocab_size=model_config.hf_text_config.vocab_size,
+        )

         if guided_params.json:
             if not isinstance(guided_params.json, str):
@@ -218,7 +199,6 @@ class GrammarConfig:
                 raise ValueError(str(err)) from err

             return cls(json_str=json_str,
-                       vocab_size=model_config.hf_text_config.vocab_size,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
                        tokenizer_data=tokenizer_data,
@@ -246,14 +226,12 @@ class GrammarConfig:
                 raise ValueError(str(err)) from err

             return cls(grammar_str=grammar_str,
-                       vocab_size=model_config.hf_text_config.vocab_size,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
                        tokenizer_data=tokenizer_data)
         elif guided_params.json_object:
             return cls(
                 json_object=True,
-                vocab_size=model_config.hf_text_config.vocab_size,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
                 tokenizer_data=tokenizer_data,
@@ -267,7 +245,6 @@ class GrammarConfig:

             return cls(
                 grammar_str=choice_str,
-                vocab_size=model_config.hf_text_config.vocab_size,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
                 tokenizer_data=tokenizer_data,
@@ -291,6 +268,13 @@ class GrammarConfig:
         grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
         return grammar

+    @staticmethod
+    def tokenizer_info(tokenizer_data: TokenizerData) -> xgr.TokenizerInfo:
+        return xgr.TokenizerInfo.from_vocab_and_metadata(
+            encoded_vocab=tokenizer_data.encoded_vocab,
+            metadata=tokenizer_data.metadata,
+        )
+

 @dataclass
 class XGrammarLogitsProcessor:
@@ -299,11 +283,16 @@ class XGrammarLogitsProcessor:
     reasoner: Reasoner | None = None

     ctx: xgr.CompiledGrammar | None = None
+    tokenizer_info: xgr.TokenizerInfo = None  # type: ignore[assignment]
     token_bitmask: torch.Tensor = None  # type: ignore[assignment]
     matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
     batch_size: int = field(default=1)
     prefilled: bool = field(default=False)

+    def __post_init__(self):
+        self.tokenizer_info = self.config.tokenizer_info(
+            self.config.tokenizer_data)
+
     def __getstate__(self) -> dict[str, Any]:
         return {'config': self.config, 'reasoner': self.reasoner}

@@ -311,6 +300,8 @@ class XGrammarLogitsProcessor:
         self.config = state['config']
         self.reasoner = state['reasoner']

+        self.tokenizer_info = GrammarConfig.tokenizer_info(
+            self.config.tokenizer_data)
         self.ctx = None
         self.matchers = []
         self.batch_size = 1
@@ -352,7 +343,7 @@ class XGrammarLogitsProcessor:
                 xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
             ]
             self.token_bitmask = xgr.allocate_token_bitmask(
-                self.batch_size, self.config.vocab_size)
+                self.batch_size, self.tokenizer_info.vocab_size)

         if not self.prefilled:
             # Have not sampled a token yet
@@ -40,7 +40,7 @@ class StructuredOutputManager:
         tokenizer_group.ping()

         tokenizer = tokenizer_group.get_lora_tokenizer(None)
-        self.vocab_size = len(tokenizer.get_vocab())
+        self.vocab_size = self.vllm_config.model_config.get_vocab_size()
         if isinstance(tokenizer, MistralTokenizer):
             # NOTE: ideally, xgrammar should handle this accordingly.
             # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
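
Same fix on the V1 side: the manager's `vocab_size` now comes from the model config rather than the tokenizer. The distinction, with illustrative numbers (assumed typical for a Qwen2.5-style model, not measured in this commit):

```python
# Illustrative only: the tokenizer's vocab can be smaller than the model's.
vocab_from_tokenizer = len(tokenizer.get_vocab())             # e.g. 151665
vocab_from_model = vllm_config.model_config.get_vocab_size()  # e.g. 151936
assert vocab_from_model >= vocab_from_tokenizer  # lm_head may be padded
```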