[Bugfix][V0] XGrammar structured output supports Enum (#15878)

Signed-off-by: Leon Seidel <leon.seidel@fau.de>
This commit is contained in:
leon-seidel 2025-04-08 00:38:25 +02:00 committed by GitHub
parent fad6e2538e
commit 24f1c01e0f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 43 additions and 4 deletions

View File

@ -3,9 +3,11 @@
import json import json
import re import re
import weakref import weakref
from enum import Enum
import jsonschema import jsonschema
import pytest import pytest
from pydantic import BaseModel
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
@ -330,3 +332,44 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
# Parse to verify it is valid JSON # Parse to verify it is valid JSON
parsed_json = json.loads(generated_text) parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict) assert isinstance(parsed_json, dict)
class CarType(str, Enum):
sedan = "sedan"
suv = "SUV"
truck = "Truck"
coupe = "Coupe"
class CarDescription(BaseModel):
brand: str
model: str
car_type: CarType
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
json_schema = CarDescription.model_json_schema()
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=json_schema,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts="Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's",
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema)

View File

@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
if "pattern" in obj: if "pattern" in obj:
return True return True
# Check for enum restrictions
if "enum" in obj:
return True
# Check for numeric ranges # Check for numeric ranges
if obj.get("type") in ("integer", "number") and any( if obj.get("type") in ("integer", "number") and any(
key in obj for key in [ key in obj for key in [