[Core][V0] Enable regex support with xgrammar (#13228)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
parent
63d2705edb
commit
dc1b4a6f13
@ -286,15 +286,26 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):
|
|||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
def test_disable_guided_decoding_fallback(sample_regex, llm):
|
def test_disable_guided_decoding_fallback(sample_regex, llm):
|
||||||
|
# see has_xgrammar_unsupported_json_features()
|
||||||
|
unsupported_json = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"example": {
|
||||||
|
"type": "string",
|
||||||
|
"minLength": 5 # unsupported by xgrammar
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
sampling_params = SamplingParams(temperature=0.8,
|
sampling_params = SamplingParams(temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
guided_decoding=GuidedDecodingParams(
|
guided_decoding=GuidedDecodingParams(
|
||||||
regex=sample_regex,
|
json=unsupported_json,
|
||||||
backend="xgrammar:no-fallback"))
|
backend="xgrammar:no-fallback"))
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match="xgrammar does not support regex guided decoding"):
|
match="xgrammar does not support advanced JSON schema features "
|
||||||
|
"like enums, patterns or numeric ranges."):
|
||||||
llm.generate(prompts="This should fail",
|
llm.generate(prompts="This should fail",
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
use_tqdm=True)
|
use_tqdm=True)
|
||||||
|
@ -59,14 +59,9 @@ def maybe_backend_fallback(
|
|||||||
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
|
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
|
||||||
xgr_installed)
|
xgr_installed)
|
||||||
|
|
||||||
# xgrammar doesn't support regex, fallback to outlines
|
|
||||||
if guided_params.regex is not None:
|
|
||||||
fallback_or_error(
|
|
||||||
guided_params,
|
|
||||||
"xgrammar does not support regex guided decoding.", "outlines")
|
|
||||||
# xgrammar doesn't support some JSON schema features
|
# xgrammar doesn't support some JSON schema features
|
||||||
elif (guided_params.json is not None
|
if (guided_params.json is not None and
|
||||||
and has_xgrammar_unsupported_json_features(guided_params.json)):
|
has_xgrammar_unsupported_json_features(guided_params.json)):
|
||||||
fallback_or_error(
|
fallback_or_error(
|
||||||
guided_params,
|
guided_params,
|
||||||
"xgrammar does not support advanced JSON schema features like "
|
"xgrammar does not support advanced JSON schema features like "
|
||||||
|
@ -152,6 +152,7 @@ class GrammarConfig:
|
|||||||
grammar_str: str | None = None
|
grammar_str: str | None = None
|
||||||
json_object: bool | None = None
|
json_object: bool | None = None
|
||||||
any_whitespace: bool = True
|
any_whitespace: bool = True
|
||||||
|
regex_str: str | None = None
|
||||||
max_threads: int = 8
|
max_threads: int = 8
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -255,6 +256,13 @@ class GrammarConfig:
|
|||||||
max_threads=max_threads,
|
max_threads=max_threads,
|
||||||
tokenizer_data=tokenizer_data,
|
tokenizer_data=tokenizer_data,
|
||||||
)
|
)
|
||||||
|
elif guided_params.regex:
|
||||||
|
return cls(
|
||||||
|
regex_str=guided_params.regex,
|
||||||
|
tokenizer_hash=tokenizer_hash,
|
||||||
|
max_threads=max_threads,
|
||||||
|
tokenizer_data=tokenizer_data,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Currently only support JSON and EBNF grammar mode for xgrammar"
|
"Currently only support JSON and EBNF grammar mode for xgrammar"
|
||||||
@ -330,6 +338,8 @@ class XGrammarLogitsProcessor:
|
|||||||
self.ctx = compiler\
|
self.ctx = compiler\
|
||||||
.compile_json_schema('{"type": "object"}',
|
.compile_json_schema('{"type": "object"}',
|
||||||
any_whitespace=any_whitespace)
|
any_whitespace=any_whitespace)
|
||||||
|
elif self.config.regex_str:
|
||||||
|
self.ctx = compiler.compile_regex(self.config.regex_str)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid configuration for xgrammar logits processor")
|
"Invalid configuration for xgrammar logits processor")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user