From b53d79983c273b2775456d99c0e0890aea073512 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 18 Feb 2025 01:49:41 -0500 Subject: [PATCH] Add outlines fallback when JSON schema has enum (#13449) Signed-off-by: mgoin --- tests/entrypoints/conftest.py | 41 +++++++++++++++++++ tests/entrypoints/llm/test_guided_generate.py | 41 +++++++++++++++++++ vllm/model_executor/guided_decoding/utils.py | 4 ++ 3 files changed, 86 insertions(+) diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py index b00e168d..3b596ea3 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/conftest.py @@ -141,6 +141,47 @@ def sample_definition_json_schema(): } +@pytest.fixture +def sample_enum_json_schema(): + return { + "type": "object", + "properties": { + "status": { + "type": "string", + "enum": ["active", "inactive", + "pending"] # Literal values using enum + }, + "priority": { + "type": "string", + "enum": ["low", "medium", "high", "critical"] + }, + "category": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": ["bug", "feature", "improvement"] + }, + "severity": { + "type": "integer", + "enum": [1, 2, 3, 4, + 5] # Enum can also contain numbers + } + }, + "required": ["type", "severity"] + }, + "flags": { + "type": "array", + "items": { + "type": "string", + "enum": ["urgent", "blocked", "needs_review", "approved"] + } + } + }, + "required": ["status", "priority", "category", "flags"] + } + + @pytest.fixture def sample_guided_choice(): return [ diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 932a35a9..01d2c170 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -146,6 +146,47 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm, schema=sample_definition_json_schema) +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +def test_guided_enum_json_completion(sample_enum_json_schema, llm, + guided_decoding_backend: str): + sampling_params = SamplingParams(temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_enum_json_schema, + backend=guided_decoding_backend)) + outputs = llm.generate(prompts=[ + "Create a bug report JSON that fits this schema: " + f"{sample_enum_json_schema}. Make it for a high priority critical bug." + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, + schema=sample_enum_json_schema) + + # Additional assertions to verify enum values + assert output_json["status"] in ["active", "inactive", "pending"] + assert output_json["priority"] in ["low", "medium", "high", "critical"] + assert output_json["category"]["type"] in [ + "bug", "feature", "improvement" + ] + assert output_json["category"]["severity"] in [1, 2, 3, 4, 5] + for flag in output_json["flags"]: + assert flag in ["urgent", "blocked", "needs_review", "approved"] + + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) def test_guided_choice_completion(sample_guided_choice, llm, diff --git a/vllm/model_executor/guided_decoding/utils.py b/vllm/model_executor/guided_decoding/utils.py index 87ef4535..c3c0378e 100644 --- a/vllm/model_executor/guided_decoding/utils.py +++ b/vllm/model_executor/guided_decoding/utils.py @@ -14,6 +14,10 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool: if "pattern" in obj: return True + # Check for enum restrictions + if "enum" in obj: + return True + # Check for numeric ranges if obj.get("type") in ("integer", "number") and any( key in obj for key in [