[V1] guidance backend for structured output + auto fallback mode (#14779)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Loc Huynh <jc1da.3011@gmail.com>
Co-authored-by: Michal Moskal <michal@moskal.me>
Russell Bryant authored on 2025-03-25 00:02:33 -04:00, committed by GitHub
parent 10b34e36b9
commit a09ad90a72
9 changed files with 344 additions and 110 deletions

requirements file (common dependencies)

@@ -18,7 +18,7 @@ pillow # Required for image processing
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.11, < 0.11
-llguidance >= 0.7.2, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
+llguidance >= 0.7.9, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
 outlines == 0.1.11
 lark == 1.2.2
 xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"

V1 guided decoding tests (LLM entrypoint)

@@ -13,7 +13,7 @@ from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams

-GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
+GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
 MODELS_TO_TEST = [
     "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
 ]
@@ -30,12 +30,13 @@ def test_guided_json_completion(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=sample_json_schema,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
     outputs = llm.generate(prompts=[
         f"Give an example JSON for an employee profile "
         f"that fits this schema: {sample_json_schema}"
@@ -111,13 +112,14 @@ def test_guided_json_object(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=100,
-                                     n=2,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json_object=True,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=100,
+        n=2,
+        guided_decoding=GuidedDecodingParams(json_object=True))

     outputs = llm.generate(
         prompts=("Generate a JSON object with curly braces for a person with "
@@ -137,12 +139,20 @@ def test_guided_json_object(

         # Parse to verify it is valid JSON
         parsed_json = json.loads(generated_text)
-        assert isinstance(parsed_json, dict)
+        allowed_types: tuple[type, ...] = (dict, )
+        if guided_decoding_backend == "xgrammar":
+            # TODO - we are currently too permissive with xgrammar and
+            # allow any valid json (typically comes back as a list or
+            # object). We can fix this by specifying a jsonschema of
+            # {"type": "object"}, but we need this fix in a release
+            # first: https://github.com/mlc-ai/xgrammar/pull/264
+            allowed_types = (dict, list)
+        assert isinstance(parsed_json, allowed_types)


 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
+                         GUIDED_DECODING_BACKENDS_V1 + ["auto"])
 @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
 def test_guided_json_unsupported_schema(
     monkeypatch: pytest.MonkeyPatch,
@@ -151,12 +161,14 @@ def test_guided_json_unsupported_schema(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=unsupported_json_schema,
-                                         backend=guided_decoding_backend))
-    with pytest.raises(ValueError,
-                       match="The provided JSON schema contains features "
-                       "not supported by xgrammar."):
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
+    if guided_decoding_backend == "xgrammar":
+        with pytest.raises(ValueError,
+                           match="The provided JSON schema contains features "
+                           "not supported by xgrammar."):
@@ -166,6 +178,26 @@ def test_guided_json_unsupported_schema(
             ] * 2,
                          sampling_params=sampling_params,
                          use_tqdm=True)
+    else:
+        # This should work for both "guidance" and "auto".
+        outputs = llm.generate(
+            prompts=("Give an example JSON object for a grade "
+                     "that fits this schema: "
+                     f"{unsupported_json_schema}"),
+            sampling_params=sampling_params,
+            use_tqdm=True)
+
+        assert outputs is not None
+        for output in outputs:
+            assert output is not None
+            assert isinstance(output, RequestOutput)
+            generated_text = output.outputs[0].text
+            assert generated_text is not None
+            print(generated_text)
+
+            # Parse to verify it is valid JSON
+            parsed_json = json.loads(generated_text)
+            assert isinstance(parsed_json, dict)


 @pytest.mark.skip_global_cleanup
@@ -179,13 +211,14 @@ def test_guided_grammar_ebnf(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar=sample_sql_ebnf,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
     outputs = llm.generate(
         prompts=("Generate a sql statement that selects col_1 from "
                  "table_1 where it is equal to 1"),
@@ -222,13 +255,14 @@ def test_guided_grammar_lark(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar=sample_sql_lark,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
     outputs = llm.generate(
         prompts=("Generate a sql statement that selects col_1 from "
                  "table_1 where it is equal to 1"),
@@ -269,16 +303,15 @@ def test_guided_grammar_ebnf_invalid(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar="not a grammar",
-                                         backend=guided_decoding_backend))
-    with pytest.raises(ValueError,
-                       match="Failed to convert the grammar "
-                       "from Lark to EBNF."):
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
+    with pytest.raises(ValueError, match="Failed to convert the grammar "):
         llm.generate(
             prompts=("Generate a sql statement that selects col_1 from "
                      "table_1 where it is equal to 1"),
@@ -298,12 +331,13 @@ def test_guided_regex(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         regex=sample_regex,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(regex=sample_regex))
     outputs = llm.generate(
         prompts=[
             f"Give an example IPv4 address with this regex: {sample_regex}"
@@ -335,12 +369,13 @@ def test_guided_choice_completion(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         choice=sample_guided_choice,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
     outputs = llm.generate(
         prompts="The best language for type-safe systems programming is ",
         sampling_params=sampling_params,
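
The common thread in the test changes above: the structured-output backend is now selected when the engine is constructed (LLM(..., guided_decoding_backend=...)) instead of per request through GuidedDecodingParams(backend=...). A minimal sketch of the new call pattern, mirroring the tests; the model name and schema are illustrative, and VLLM_USE_V1=1 is assumed just as the tests set via monkeypatch:

import os

os.environ["VLLM_USE_V1"] = "1"  # use the V1 engine, as in the tests above

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Illustrative schema, not the fixture the tests use.
schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

# The backend is an engine-level setting now: "xgrammar", "guidance", or "auto".
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          max_model_len=1024,
          guided_decoding_backend="guidance")
sampling_params = SamplingParams(
    temperature=1.0,
    max_tokens=500,
    guided_decoding=GuidedDecodingParams(json=schema))
outputs = llm.generate(
    prompts=[f"Give an example JSON for a person that fits this schema: {schema}"],
    sampling_params=sampling_params)
print(outputs[0].outputs[0].text)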

vllm/config.py

@@ -2800,12 +2800,17 @@ class DecodingConfig:
         return hash_str

     def __post_init__(self):
-        valid_guided_backends = [
-            'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance'
-        ]
+        v0_valid_guided_backends = [
+            'outlines', 'lm-format-enforcer', 'xgrammar'
+        ]
+        v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']

         backend = GuidedDecodingParams(
             backend=self.guided_decoding_backend).backend_name
+        if envs.VLLM_USE_V1:
+            valid_guided_backends = v1_valid_guided_backends
+        else:
+            valid_guided_backends = v0_valid_guided_backends
         if backend not in valid_guided_backends:
             raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
                              f" must be one of {valid_guided_backends}")

vllm/engine/arg_utils.py

@@ -391,16 +391,13 @@ class EngineArgs:
             default='xgrammar',
             help='Which engine will be used for guided decoding'
             ' (JSON schema / regex etc) by default. Currently support '
-            'https://github.com/outlines-dev/outlines, '
-            'https://github.com/mlc-ai/xgrammar, and '
-            'https://github.com/noamgat/lm-format-enforcer.'
-            ' Can be overridden per request via guided_decoding_backend'
-            ' parameter.\n'
-            'Backend-specific options can be supplied in a comma-separated '
-            'list following a colon after the backend name. Valid backends and '
-            'all available options are: [xgrammar:no-fallback, '
-            'xgrammar:disable-any-whitespace, '
-            'outlines:no-fallback, lm-format-enforcer:no-fallback]')
+            'https://github.com/mlc-ai/xgrammar and '
+            'https://github.com/guidance-ai/llguidance.'
+            'Valid backend values are "xgrammar", "guidance", and "auto". '
+            'With "auto", we will make opinionated choices based on request'
+            'contents and what the backend libraries currently support, so '
+            'the behavior is subject to change in each release. '
+            'The default is xgrammar.')
         parser.add_argument(
             '--logits-processor-pattern',
             type=nullable_str,
@@ -1539,9 +1536,9 @@
                                recommend_to_remove=False)
             return False

-        # Only support Xgrammar for guided decoding so far.
+        # Xgrammar and Guidance are supported.
         SUPPORTED_GUIDED_DECODING = [
-            "xgrammar", "xgrammar:disable-any-whitespace"
+            "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
         ]
         if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
             _raise_or_fallback(feature_name="--guided-decoding-backend",

vllm/v1/engine/processor.py

@@ -4,7 +4,6 @@ import time
 from collections.abc import Mapping
 from typing import Optional, Union

-import vllm.platforms
 from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                          PromptType, SingletonInputsAdapter)
@@ -20,7 +19,10 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.structured_output.utils import validate_structured_output_request
+from vllm.v1.structured_output.backend_guidance import (
+    validate_guidance_grammar)
+from vllm.v1.structured_output.utils import (
+    validate_structured_output_request_xgrammar)


 class Processor:
@@ -120,7 +122,9 @@ class Processor:
         if not params.guided_decoding or not self.decoding_config:
             return

-        supported_backends = ["xgrammar", "xgrammar:disable-any-whitespace"]
+        supported_backends = [
+            "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
+        ]
         engine_level_backend = self.decoding_config.guided_decoding_backend
         if engine_level_backend not in supported_backends:
             raise ValueError(f"Only {supported_backends} structured output is "
@@ -134,10 +138,31 @@ class Processor:
         else:
             params.guided_decoding.backend = engine_level_backend

-        if vllm.platforms.current_platform.is_tpu():
-            raise ValueError("Structured output is not supported on TPU.")
-
-        validate_structured_output_request(params)
+        # Request content validation
+
+        if engine_level_backend == "xgrammar":
+            # xgrammar with no fallback
+            validate_structured_output_request_xgrammar(params)
+            params.guided_decoding.backend = "xgrammar"
+        elif engine_level_backend == "auto":
+            # "auto" is an opt-in to opinionated behavior where we try to
+            # choose a backend based on request contents. This is not the
+            # default as it is less predictable and subject to change
+            # between releases as feature support changes.
+            try:
+                validate_structured_output_request_xgrammar(params)
+                params.guided_decoding.backend = "xgrammar"
+            except ValueError:
+                # The request includes some jsonschema feature(s) that
+                # are not supported in xgrammar. Fall back to guidance.
+                params.guided_decoding.backend = "guidance"
+
+        if params.guided_decoding.backend == "guidance":
+            # TODO ideally we would have the LLTokenizer here as Lark syntax
+            # allows <|special_token|> and similar, see
+            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
+            # Without tokenizer these are disallowed in grammars.
+            validate_guidance_grammar(params, tokenizer=None)

     def process_inputs(
         self,
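
The core of the new "auto" mode is the try/except above: validate the request against xgrammar's supported feature set first, and only fall back to the guidance backend when xgrammar rejects it. A condensed sketch of that decision, using the validators introduced in this commit:

from vllm.sampling_params import SamplingParams
from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar
from vllm.v1.structured_output.utils import (
    validate_structured_output_request_xgrammar)


def resolve_auto_backend(params: SamplingParams) -> str:
    """Condensed restatement of the 'auto' branch in Processor above."""
    try:
        validate_structured_output_request_xgrammar(params)
        return "xgrammar"
    except ValueError:
        # The request uses features xgrammar does not support yet; make sure
        # guidance can handle it, then route the request to guidance.
        validate_guidance_grammar(params, tokenizer=None)
        return "guidance"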

vllm/v1/structured_output/__init__.py

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Optional

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.v1.structured_output.backend_guidance import GuidanceBackend
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                      StructuredOutputGrammar)
@@ -50,6 +51,8 @@ class StructuredOutputManager:
                 XgrammarBackend)

             self.backend = XgrammarBackend(self.vllm_config)
+        elif backend_name == "guidance":
+            self.backend = GuidanceBackend(self.vllm_config)
         else:
             raise ValueError(
                 f"Unsupported structured output backend: {backend_name}")

vllm/v1/structured_output/backend_guidance.py (new file)

@@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar,
                                                     StructuredOutputOptions)
from vllm.v1.structured_output.request import get_structured_output_key

if TYPE_CHECKING:
    import llguidance
    import llguidance.hf as llguidance_hf
    import llguidance.torch as llguidance_torch
else:
    llguidance = LazyLoader("llguidance", globals(), "llguidance")
    llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf")
    llguidance_torch = LazyLoader("llguidance.torch", globals(),
                                  "llguidance.torch")

logger = init_logger(__name__)


class GuidanceBackend(StructuredOutputBackend):

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config
        tokenizer_group = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            parallel_config=vllm_config.parallel_config,
            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
        tokenizer_group.ping()
        self.vllm_config = vllm_config
        self.vocab_size = vllm_config.model_config.get_vocab_size()

        tokenizer = tokenizer_group.get_lora_tokenizer(None)
        self.ll_tokenizer = llguidance_hf.from_tokenizer(tokenizer, None)

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        self.serialized_grammar = serialize_guidance_grammar(
            request_type, grammar_spec)

        ll_matcher = llguidance.LLMatcher(
            self.ll_tokenizer,
            self.serialized_grammar,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )

        r = GuidanceGrammar(
            ll_matcher=ll_matcher,
            ll_tokenizer=self.ll_tokenizer,
            vocab_size=self.vocab_size,
        )

        r.check_error()
        return r

    def allocate_token_bitmask(self, max_num_seqs: int):
        return llguidance_torch.allocate_token_bitmask(
            max_num_seqs, self.ll_tokenizer.vocab_size)


@dataclass
class GuidanceGrammar(StructuredOutputGrammar):
    ll_matcher: llguidance.LLMatcher
    ll_tokenizer: llguidance.LLTokenizer
    vocab_size: int
    printed_error: bool = False
    terminated: bool = False

    def check_error(self):
        if not self.printed_error:
            err = self.ll_matcher.get_error()
            if err:
                self.printed_error = True
                logger.warning("LLMatcher error: %s", err)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the parser.

        Returns True if the parser was advanced successfully.
        Returns False if the parser failed to advance.
        """

        if self.ll_tokenizer.eos_token in tokens:
            self.terminated = True

        if self.ll_matcher.is_stopped():
            return True

        # TODO - Add jump decoding support in the future:
        # self.ll_matcher.compute_ff_bytes() - this should always work
        # self.ll_matcher.compute_ff_tokens() - this only works for
        #   "canonical" tokenizers
        # For conversion between the two, see
        # https://github.com/guidance-ai/llguidance/blob/main/docs/fast_forward.md

        r = self.ll_matcher.consume_tokens(tokens)

        self.check_error()

        return r

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        # this will automatically return [EOS] mask if the matcher is stopped
        # or otherwise in an error state
        llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, idx)
        self.check_error()

    def is_terminated(self) -> bool:
        return self.terminated

    def reset(self):
        # This method may be not needed anymore? TODO
        self.ll_matcher.reset()


def serialize_guidance_grammar(request_type: StructuredOutputOptions,
                               grammar_spec: str) -> str:
    if request_type == StructuredOutputOptions.JSON:
        # TODO: make whitespace_flexible configurable
        return llguidance.LLMatcher.grammar_from_json_schema(
            grammar_spec, defaults={
                "whitespace_flexible": True,
            })
    elif request_type == StructuredOutputOptions.JSON_OBJECT:
        return llguidance.LLMatcher.grammar_from_json_schema(
            '{"type": "object"}', defaults={
                "whitespace_flexible": True,
            })
    else:
        if request_type == StructuredOutputOptions.REGEX:
            tp = "regex"
        elif request_type == StructuredOutputOptions.GRAMMAR:
            tp = "grammar"
        elif request_type == StructuredOutputOptions.CHOICE:
            tp = "choice"
        else:
            logger.error("Validation should have already occurred. "
                         "Please file an issue.")
            raise ValueError("grammar is not of valid supported types. "
                             f"({request_type!s})")
        return llguidance.grammar_from(tp, grammar_spec)


def validate_guidance_grammar(
        sampling_params: SamplingParams,
        tokenizer: Optional[llguidance.LLTokenizer] = None) -> None:
    tp, grm = get_structured_output_key(sampling_params)
    guidance_grm = serialize_guidance_grammar(tp, grm)
    err = llguidance.LLMatcher.validate_grammar(guidance_grm,
                                                tokenizer=tokenizer)
    if err:
        raise ValueError(f"Grammar error: {err}")
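
The two module-level helpers can also be exercised on their own, which is how Processor pre-validates guidance requests. A small sketch; the schema is illustrative, and tokenizer=None matches how Processor calls it, so special-token references in Lark grammars cannot be checked at this point:

from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.structured_output.backend_guidance import (
    serialize_guidance_grammar, validate_guidance_grammar)
from vllm.v1.structured_output.backend_types import StructuredOutputOptions

schema = '{"type": "object", "properties": {"grade": {"type": "string"}}}'

# Compile a JSON schema into llguidance's serialized grammar form.
serialized = serialize_guidance_grammar(StructuredOutputOptions.JSON, schema)

# Or validate a whole request up front, as Processor does.
params = SamplingParams(guided_decoding=GuidedDecodingParams(json=schema))
validate_guidance_grammar(params, tokenizer=None)  # raises ValueError if invalid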

vllm/v1/structured_output/request.py

@@ -53,7 +53,12 @@ class StructuredOutputRequest:

     @functools.cached_property
     def structured_output_key(self) -> StructuredOutputKey:
-        params = self.sampling_params.guided_decoding
-        assert params is not None, "params can't be None."
-        if params.json is not None:
-            if not isinstance(params.json, str):
+        return get_structured_output_key(self.sampling_params)
+
+
+def get_structured_output_key(
+        sampling_params: SamplingParams) -> StructuredOutputKey:
+    params = sampling_params.guided_decoding
+    assert params is not None, "params can't be None."
+    if params.json is not None:
+        if not isinstance(params.json, str):
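
get_structured_output_key returns the same (StructuredOutputOptions, spec-string) pair the cached property used to compute, so it can now be reused outside StructuredOutputRequest; backend_guidance.validate_guidance_grammar relies on exactly that. A short sketch, assuming a regex request maps to (StructuredOutputOptions.REGEX, pattern) in the same way the JSON branch shown above maps schemas:

from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.structured_output.request import get_structured_output_key

params = SamplingParams(
    guided_decoding=GuidedDecodingParams(regex=r"\d{1,3}(\.\d{1,3}){3}"))
option, spec = get_structured_output_key(params)
# Expected: option == StructuredOutputOptions.REGEX and spec is the pattern.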

vllm/v1/structured_output/utils.py

@@ -239,7 +239,7 @@ def choice_as_grammar(choice: list[str]) -> str:
     return grammar


-def validate_structured_output_request(
+def validate_structured_output_request_xgrammar(
         sampling_params: SamplingParams) -> None:
     """Validate that the request is supported by structured output.