Add outlines fallback when JSON schema has enum (#13449)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
9915912f7f
commit
b53d79983c
@ -141,6 +141,47 @@ def sample_definition_json_schema():
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
def sample_enum_json_schema():
    """Provide a JSON schema in which every leaf value is limited by ``enum``.

    The schema exercises string enums at the top level, a string and an
    integer enum inside a nested object, and a string enum applied to the
    items of an array. A new dict is built on every call.
    """
    # Literal values using enum
    status_schema = {
        "type": "string",
        "enum": ["active", "inactive", "pending"],
    }
    priority_schema = {
        "type": "string",
        "enum": ["low", "medium", "high", "critical"],
    }
    # Nested object whose two members are both required.
    category_schema = {
        "type": "object",
        "properties": {
            "type": {
                "type": "string",
                "enum": ["bug", "feature", "improvement"],
            },
            # Enum can also contain numbers
            "severity": {
                "type": "integer",
                "enum": [1, 2, 3, 4, 5],
            },
        },
        "required": ["type", "severity"],
    }
    # Array whose elements are each drawn from a fixed set of strings.
    flags_schema = {
        "type": "array",
        "items": {
            "type": "string",
            "enum": ["urgent", "blocked", "needs_review", "approved"],
        },
    }

    # Key order is kept stable because the dict's repr is embedded in prompts.
    return {
        "type": "object",
        "properties": {
            "status": status_schema,
            "priority": priority_schema,
            "category": category_schema,
            "flags": flags_schema,
        },
        "required": ["status", "priority", "category", "flags"],
    }
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_guided_choice():
|
||||
return [
|
||||
|
@ -146,6 +146,47 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
|
||||
schema=sample_definition_json_schema)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
                                     guided_decoding_backend: str):
    """Check that guided JSON decoding honours ``enum`` constraints.

    Two identical prompts are generated with the enum schema as the guided
    decoding target. Each completion must parse as JSON, validate against
    the schema, and use only the enumerated values for every field.
    """
    guided_params = GuidedDecodingParams(json=sample_enum_json_schema,
                                         backend=guided_decoding_backend)
    sampling_params = SamplingParams(temperature=1.0,
                                     max_tokens=1000,
                                     guided_decoding=guided_params)

    # The schema is interpolated with its Python repr, exactly as written.
    request_prompt = (
        "Create a bug report JSON that fits this schema: "
        f"{sample_enum_json_schema}. Make it for a high priority critical bug."
    )
    outputs = llm.generate(prompts=[request_prompt] * 2,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    assert outputs is not None

    # Allowed values per field, mirroring the enums declared in the schema.
    allowed_status = ("active", "inactive", "pending")
    allowed_priority = ("low", "medium", "high", "critical")
    allowed_category_type = ("bug", "feature", "improvement")
    allowed_severity = (1, 2, 3, 4, 5)
    allowed_flags = ("urgent", "blocked", "needs_review", "approved")

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

        # Structural check first: the text must be JSON matching the schema.
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json,
                            schema=sample_enum_json_schema)

        # Additional assertions to verify enum values
        assert output_json["status"] in allowed_status
        assert output_json["priority"] in allowed_priority
        category = output_json["category"]
        assert category["type"] in allowed_category_type
        assert category["severity"] in allowed_severity
        for flag in output_json["flags"]:
            assert flag in allowed_flags
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||
def test_guided_choice_completion(sample_guided_choice, llm,
|
||||
|
@ -14,6 +14,10 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
|
||||
if "pattern" in obj:
|
||||
return True
|
||||
|
||||
# Check for enum restrictions
|
||||
if "enum" in obj:
|
||||
return True
|
||||
|
||||
# Check for numeric ranges
|
||||
if obj.get("type") in ("integer", "number") and any(
|
||||
key in obj for key in [
|
||||
|
Loading…
x
Reference in New Issue
Block a user