Add outlines fallback when JSON schema has enum (#13449)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-02-18 01:49:41 -05:00 committed by GitHub
parent 9915912f7f
commit b53d79983c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 86 additions and 0 deletions

View File

@ -141,6 +141,47 @@ def sample_definition_json_schema():
}
@pytest.fixture
def sample_enum_json_schema():
return {
"type": "object",
"properties": {
"status": {
"type": "string",
"enum": ["active", "inactive",
"pending"] # Literal values using enum
},
"priority": {
"type": "string",
"enum": ["low", "medium", "high", "critical"]
},
"category": {
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": ["bug", "feature", "improvement"]
},
"severity": {
"type": "integer",
"enum": [1, 2, 3, 4,
5] # Enum can also contain numbers
}
},
"required": ["type", "severity"]
},
"flags": {
"type": "array",
"items": {
"type": "string",
"enum": ["urgent", "blocked", "needs_review", "approved"]
}
}
},
"required": ["status", "priority", "category", "flags"]
}
@pytest.fixture
def sample_guided_choice():
return [

View File

@ -146,6 +146,47 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
schema=sample_definition_json_schema)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_enum_json_schema,
backend=guided_decoding_backend))
outputs = llm.generate(prompts=[
"Create a bug report JSON that fits this schema: "
f"{sample_enum_json_schema}. Make it for a high priority critical bug."
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json,
schema=sample_enum_json_schema)
# Additional assertions to verify enum values
assert output_json["status"] in ["active", "inactive", "pending"]
assert output_json["priority"] in ["low", "medium", "high", "critical"]
assert output_json["category"]["type"] in [
"bug", "feature", "improvement"
]
assert output_json["category"]["severity"] in [1, 2, 3, 4, 5]
for flag in output_json["flags"]:
assert flag in ["urgent", "blocked", "needs_review", "approved"]
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_choice_completion(sample_guided_choice, llm,

View File

@ -14,6 +14,10 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
if "pattern" in obj:
return True
# Check for enum restrictions
if "enum" in obj:
return True
# Check for numeric ranges
if obj.get("type") in ("integer", "number") and any(
key in obj for key in [