[Bugfix][v1] xgrammar structured output supports Enum. (#15594)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
parent
91276c5721
commit
3b00ff9138
@ -4,10 +4,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import jsonschema
|
import jsonschema
|
||||||
import pytest
|
import pytest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from vllm.entrypoints.llm import LLM
|
from vllm.entrypoints.llm import LLM
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
@ -390,3 +392,54 @@ def test_guided_choice_completion(
|
|||||||
assert generated_text is not None
|
assert generated_text is not None
|
||||||
assert generated_text in sample_guided_choice
|
assert generated_text in sample_guided_choice
|
||||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
|
||||||
|
|
||||||
|
class CarType(str, Enum):
|
||||||
|
sedan = "sedan"
|
||||||
|
suv = "SUV"
|
||||||
|
truck = "Truck"
|
||||||
|
coupe = "Coupe"
|
||||||
|
|
||||||
|
|
||||||
|
class CarDescription(BaseModel):
|
||||||
|
brand: str
|
||||||
|
model: str
|
||||||
|
car_type: CarType
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_global_cleanup
|
||||||
|
@pytest.mark.parametrize("guided_decoding_backend",
|
||||||
|
GUIDED_DECODING_BACKENDS_V1)
|
||||||
|
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
|
||||||
|
def test_guided_json_completion_with_enum(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
guided_decoding_backend: str,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
|
llm = LLM(model=model_name,
|
||||||
|
max_model_len=1024,
|
||||||
|
guided_decoding_backend=guided_decoding_backend)
|
||||||
|
json_schema = CarDescription.model_json_schema()
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=1.0,
|
||||||
|
max_tokens=1000,
|
||||||
|
guided_decoding=GuidedDecodingParams(json=json_schema))
|
||||||
|
outputs = llm.generate(
|
||||||
|
prompts="Generate a JSON with the brand, model and car_type of"
|
||||||
|
"the most iconic car from the 90's",
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True)
|
||||||
|
|
||||||
|
assert outputs is not None
|
||||||
|
|
||||||
|
for output in outputs:
|
||||||
|
assert output is not None
|
||||||
|
assert isinstance(output, RequestOutput)
|
||||||
|
prompt = output.prompt
|
||||||
|
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
assert generated_text is not None
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
output_json = json.loads(generated_text)
|
||||||
|
jsonschema.validate(instance=output_json, schema=json_schema)
|
||||||
|
@ -26,10 +26,6 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
|
|||||||
if "pattern" in obj:
|
if "pattern" in obj:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check for enum restrictions
|
|
||||||
if "enum" in obj:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check for numeric ranges
|
# Check for numeric ranges
|
||||||
if obj.get("type") in ("integer", "number") and any(
|
if obj.get("type") in ("integer", "number") and any(
|
||||||
key in obj
|
key in obj
|
||||||
|
Loading…
x
Reference in New Issue
Block a user