[Bugfix][V0] XGrammar structured output supports Enum (#15878)
Signed-off-by: Leon Seidel <leon.seidel@fau.de>
This commit is contained in:
parent
fad6e2538e
commit
24f1c01e0f
@ -3,9 +3,11 @@
|
||||
import json
|
||||
import re
|
||||
import weakref
|
||||
from enum import Enum
|
||||
|
||||
import jsonschema
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.entrypoints.llm import LLM
|
||||
@ -330,3 +332,44 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
|
||||
# Parse to verify it is valid JSON
|
||||
parsed_json = json.loads(generated_text)
|
||||
assert isinstance(parsed_json, dict)
|
||||
|
||||
|
||||
class CarType(str, Enum):
|
||||
sedan = "sedan"
|
||||
suv = "SUV"
|
||||
truck = "Truck"
|
||||
coupe = "Coupe"
|
||||
|
||||
|
||||
class CarDescription(BaseModel):
|
||||
brand: str
|
||||
model: str
|
||||
car_type: CarType
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
|
||||
json_schema = CarDescription.model_json_schema()
|
||||
sampling_params = SamplingParams(temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
json=json_schema,
|
||||
backend=guided_decoding_backend))
|
||||
outputs = llm.generate(
|
||||
prompts="Generate a JSON with the brand, model and car_type of"
|
||||
"the most iconic car from the 90's",
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
output_json = json.loads(generated_text)
|
||||
jsonschema.validate(instance=output_json, schema=json_schema)
|
@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
|
||||
if "pattern" in obj:
|
||||
return True
|
||||
|
||||
# Check for enum restrictions
|
||||
if "enum" in obj:
|
||||
return True
|
||||
|
||||
# Check for numeric ranges
|
||||
if obj.get("type") in ("integer", "number") and any(
|
||||
key in obj for key in [
|
||||
|
Loading…
x
Reference in New Issue
Block a user