Add outlines fallback when JSON schema has enum (#13449)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
9915912f7f
commit
b53d79983c
@ -141,6 +141,47 @@ def sample_definition_json_schema():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_enum_json_schema():
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"status": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["active", "inactive",
|
||||||
|
"pending"] # Literal values using enum
|
||||||
|
},
|
||||||
|
"priority": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["low", "medium", "high", "critical"]
|
||||||
|
},
|
||||||
|
"category": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["bug", "feature", "improvement"]
|
||||||
|
},
|
||||||
|
"severity": {
|
||||||
|
"type": "integer",
|
||||||
|
"enum": [1, 2, 3, 4,
|
||||||
|
5] # Enum can also contain numbers
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["type", "severity"]
|
||||||
|
},
|
||||||
|
"flags": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["urgent", "blocked", "needs_review", "approved"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["status", "priority", "category", "flags"]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_guided_choice():
|
def sample_guided_choice():
|
||||||
return [
|
return [
|
||||||
|
@ -146,6 +146,47 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
|
|||||||
schema=sample_definition_json_schema)
|
schema=sample_definition_json_schema)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_global_cleanup
|
||||||
|
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||||
|
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
|
||||||
|
guided_decoding_backend: str):
|
||||||
|
sampling_params = SamplingParams(temperature=1.0,
|
||||||
|
max_tokens=1000,
|
||||||
|
guided_decoding=GuidedDecodingParams(
|
||||||
|
json=sample_enum_json_schema,
|
||||||
|
backend=guided_decoding_backend))
|
||||||
|
outputs = llm.generate(prompts=[
|
||||||
|
"Create a bug report JSON that fits this schema: "
|
||||||
|
f"{sample_enum_json_schema}. Make it for a high priority critical bug."
|
||||||
|
] * 2,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True)
|
||||||
|
|
||||||
|
assert outputs is not None
|
||||||
|
|
||||||
|
for output in outputs:
|
||||||
|
assert output is not None
|
||||||
|
assert isinstance(output, RequestOutput)
|
||||||
|
prompt = output.prompt
|
||||||
|
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
assert generated_text is not None
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
output_json = json.loads(generated_text)
|
||||||
|
jsonschema.validate(instance=output_json,
|
||||||
|
schema=sample_enum_json_schema)
|
||||||
|
|
||||||
|
# Additional assertions to verify enum values
|
||||||
|
assert output_json["status"] in ["active", "inactive", "pending"]
|
||||||
|
assert output_json["priority"] in ["low", "medium", "high", "critical"]
|
||||||
|
assert output_json["category"]["type"] in [
|
||||||
|
"bug", "feature", "improvement"
|
||||||
|
]
|
||||||
|
assert output_json["category"]["severity"] in [1, 2, 3, 4, 5]
|
||||||
|
for flag in output_json["flags"]:
|
||||||
|
assert flag in ["urgent", "blocked", "needs_review", "approved"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||||
def test_guided_choice_completion(sample_guided_choice, llm,
|
def test_guided_choice_completion(sample_guided_choice, llm,
|
||||||
|
@ -14,6 +14,10 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
|
|||||||
if "pattern" in obj:
|
if "pattern" in obj:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# Check for enum restrictions
|
||||||
|
if "enum" in obj:
|
||||||
|
return True
|
||||||
|
|
||||||
# Check for numeric ranges
|
# Check for numeric ranges
|
||||||
if obj.get("type") in ("integer", "number") and any(
|
if obj.get("type") in ("integer", "number") and any(
|
||||||
key in obj for key in [
|
key in obj for key in [
|
||||||
|
Loading…
x
Reference in New Issue
Block a user