diff --git a/tests/v1/structured_output/test_utils.py b/tests/v1/structured_output/test_utils.py index 0929f990..337df451 100644 --- a/tests/v1/structured_output/test_utils.py +++ b/tests/v1/structured_output/test_utils.py @@ -2,7 +2,7 @@ import pytest -from vllm.v1.structured_output.utils import ( +from vllm.v1.structured_output.backend_xgrammar import ( has_xgrammar_unsupported_json_features) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 6d3290f1..396fe25e 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -22,8 +22,8 @@ from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) -from vllm.v1.structured_output.utils import ( - validate_structured_output_request_xgrammar) +from vllm.v1.structured_output.backend_xgrammar import ( + validate_xgrammar_grammar) class Processor: @@ -165,7 +165,7 @@ class Processor: # Request content validation if engine_level_backend.startswith("xgrammar"): # xgrammar with no fallback - validate_structured_output_request_xgrammar(params) + validate_xgrammar_grammar(params) params.guided_decoding.backend = engine_level_backend elif engine_level_backend == "auto": # "auto" is an opt-in to opinionated behavior where we try to @@ -173,7 +173,7 @@ class Processor: # default as it is less predictable and subject to change # between releases as feature support changes. 
try: - validate_structured_output_request_xgrammar(params) + validate_xgrammar_grammar(params) params.guided_decoding.backend = "xgrammar" except ValueError: # The request includes some jsonschema feature(s) that diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 83f2c643..c9839bd7 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -1,19 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 +import json from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch import vllm.envs from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, StructuredOutputGrammar, StructuredOutputOptions) +from vllm.v1.structured_output.utils import (choice_as_grammar, + convert_lark_to_ebnf, + grammar_is_likely_lark) if TYPE_CHECKING: import xgrammar as xgr @@ -156,3 +161,112 @@ class XgrammarGrammar(StructuredOutputGrammar): def reset(self): self.num_processed_tokens = 0 self.matcher.reset() + + +def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: + """Check if JSON schema contains features unsupported by xgrammar.""" + + def check_object(obj: dict[str, Any]) -> bool: + if not isinstance(obj, dict): + return False + + # Check for pattern restrictions + if "pattern" in obj: + return True + + # Check for numeric ranges + if obj.get("type") in ("integer", "number") and any( + key in obj + for key in ("minimum", "maximum", "exclusiveMinimum", + "exclusiveMaximum", "multipleOf")): + return True + + # Check for array unsupported keywords + if obj.get("type") == 
"array" and any( + key in obj + for key in ("uniqueItems", "contains", "minContains", + "maxContains", "minItems", "maxItems")): + return True + + # Unsupported keywords for strings + if obj.get("type") == "string" and "format" in obj: + return True + + # Unsupported keywords for objects + if obj.get("type") == "object" and any( + key in obj for key in ("minProperties", "maxProperties", + "propertyNames", "patternProperties")): + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) + + +def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: + """Validate that the request is supported by structured output. + + Raises ValueError if the request is not supported. + """ + if sampling_params.guided_decoding is None: + return + + gd_params = sampling_params.guided_decoding + + if gd_params.regex: + try: + xgr.Grammar.from_regex(gd_params.regex) + except Exception as err: + raise ValueError("Failed to transform regex into a grammar: " + f"{err}") from err + + if gd_params.choice: + choice_grammar = choice_as_grammar(gd_params.choice) + try: + xgr.Grammar.from_ebnf(choice_grammar) + except Exception as err: + raise ValueError("Failed to transform choices into a grammar: " + f"{err}") from err + gd_params.choice = None + gd_params.grammar = choice_grammar + return + + if gd_params.json: + if isinstance(gd_params.json, str): + try: + schema = json.loads(gd_params.json) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + else: + schema = gd_params.json + + if has_xgrammar_unsupported_json_features(schema): + raise ValueError("The provided JSON schema contains features not " + "supported by xgrammar.") + return + + if gd_params.grammar: 
+ if grammar_is_likely_lark(gd_params.grammar): + # xgrammar supports EBNF grammars only + try: + gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar) + except ValueError as e: + raise ValueError( + "Failed to convert the grammar from Lark to EBNF. ") from e + + # Test parsing EBNF grammar, possibly already converted from Lark + try: + # parse the grammar, but we aren't compiling it. + xgr.Grammar.from_ebnf(gd_params.grammar) + except Exception as e: + raise ValueError("Invalid grammar specification.") from e diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index 56eed959..f33f4972 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -2,67 +2,7 @@ from __future__ import annotations -import json import re -from typing import TYPE_CHECKING, Any - -from vllm.sampling_params import SamplingParams -from vllm.utils import LazyLoader - -if TYPE_CHECKING: - import xgrammar as xgr -else: - xgr = LazyLoader("xgr", globals(), "xgrammar") - - -def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: - """Check if JSON schema contains features unsupported by xgrammar.""" - - def check_object(obj: dict[str, Any]) -> bool: - if not isinstance(obj, dict): - return False - - # Check for pattern restrictions - if "pattern" in obj: - return True - - # Check for numeric ranges - if obj.get("type") in ("integer", "number") and any( - key in obj - for key in ("minimum", "maximum", "exclusiveMinimum", - "exclusiveMaximum", "multipleOf")): - return True - - # Check for array unsupported keywords - if obj.get("type") == "array" and any( - key in obj - for key in ("uniqueItems", "contains", "minContains", - "maxContains", "minItems", "maxItems")): - return True - - # Unsupported keywords for strings - if obj.get("type") == "string" and "format" in obj: - return True - - # Unsupported keywords for objects - if obj.get("type") == "object" and any( - key in obj for key in ("minProperties", 
"maxProperties", - "propertyNames", "patternProperties")): - return True - - # Recursively check all nested objects and arrays - for value in obj.values(): - if isinstance(value, dict): - if check_object(value): - return True - elif isinstance(value, list): - for item in value: - if isinstance(item, dict) and check_object(item): - return True - - return False - - return check_object(schema) def grammar_is_likely_lark(grammar_str: str) -> bool: @@ -232,63 +172,3 @@ def choice_as_grammar(choice: list[str]) -> str: escaped_choices = (escape_ebnf_string(c) for c in choice) grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices)) return grammar - - -def validate_structured_output_request_xgrammar( - sampling_params: SamplingParams) -> None: - """Validate that the request is supported by structured output. - - Raises ValueError if the request is not supported. - """ - if sampling_params.guided_decoding is None: - return - - gd_params = sampling_params.guided_decoding - - if gd_params.regex: - try: - xgr.Grammar.from_regex(gd_params.regex) - except Exception as err: - raise ValueError("Failed to transform regex into a grammar: " - f"{err}") from err - - if gd_params.choice: - choice_grammar = choice_as_grammar(gd_params.choice) - try: - xgr.Grammar.from_ebnf(choice_grammar) - except Exception as err: - raise ValueError("Failed to transform choices into a grammar: " - "{err}") from err - gd_params.choice = None - gd_params.grammar = choice_grammar - return - - if gd_params.json: - if isinstance(gd_params.json, str): - try: - schema = json.loads(gd_params.json) - except json.JSONDecodeError as e: - raise ValueError("Invalid JSON grammar specification.") from e - else: - schema = gd_params.json - - if has_xgrammar_unsupported_json_features(schema): - raise ValueError("The provided JSON schema contains features not " - "supported by xgrammar.") - return - - if gd_params.grammar: - if grammar_is_likely_lark(gd_params.grammar): - # xgrammar supports EBNF grammars 
only - try: - gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar) - except ValueError as e: - raise ValueError( - "Failed to convert the grammar from Lark to EBNF. ") from e - - # Test parsing EBNF grammar, possibly already converted from Lark - try: - # parse the grammar, but we aren't compiling it. - xgr.Grammar.from_ebnf(gd_params.grammar) - except Exception as e: - raise ValueError("Invalid grammar specification.") from e