[Bugfix][v1] xgrammar structured output supports Enum. (#15594)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
Chauncey 2025-03-28 21:14:53 +08:00 committed by GitHub
parent 91276c5721
commit 3b00ff9138
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 4 deletions

View File

@ -4,10 +4,12 @@ from __future__ import annotations
import json
import re
from enum import Enum
from typing import Any
import jsonschema
import pytest
from pydantic import BaseModel
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
@ -390,3 +392,54 @@ def test_guided_choice_completion(
assert generated_text is not None
assert generated_text in sample_guided_choice
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
class CarType(str, Enum):
sedan = "sedan"
suv = "SUV"
truck = "Truck"
coupe = "Coupe"
class CarDescription(BaseModel):
brand: str
model: str
car_type: CarType
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_completion_with_enum(
monkeypatch: pytest.MonkeyPatch,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
json_schema = CarDescription.model_json_schema()
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=json_schema))
outputs = llm.generate(
prompts="Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's",
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema)

View File

@ -26,10 +26,6 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
if "pattern" in obj:
return True
# Check for enum restrictions
if "enum" in obj:
return True
# Check for numeric ranges
if obj.get("type") in ("integer", "number") and any(
key in obj