From 98d01d3ce2a4d06e85348e375e726c40bee0bdf0 Mon Sep 17 00:00:00 2001
From: Guillaume Calmettes
Date: Wed, 9 Apr 2025 14:11:10 +0200
Subject: [PATCH] [Bugfix][Frontend] respect provided default guided decoding
 backend (#15476)

Signed-off-by: Guillaume Calmettes
---
 tests/test_sampling_params.py       | 81 +++++++++++++++++++++++++++--
 vllm/entrypoints/openai/protocol.py |  2 -
 2 files changed, 78 insertions(+), 5 deletions(-)

diff --git a/tests/test_sampling_params.py b/tests/test_sampling_params.py
index 40e26ed5..9af810c4 100644
--- a/tests/test_sampling_params.py
+++ b/tests/test_sampling_params.py
@@ -1,7 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 """Tests for the SamplingParams class.
 """
+
+import pytest
+
 from vllm import SamplingParams
+from vllm.config import ModelConfig
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+
+MODEL_NAME = "Qwen/Qwen1.5-7B"
 
 
 def test_max_tokens_none():
@@ -9,6 +16,74 @@ def test_max_tokens_none():
     SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None)
 
 
-if __name__ == "__main__":
-    import pytest
-    pytest.main([__file__])
+@pytest.fixture(scope="module")
+def model_config():
+    return ModelConfig(
+        MODEL_NAME,
+        task="auto",
+        tokenizer=MODEL_NAME,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+
+
+@pytest.fixture(scope="module")
+def default_max_tokens():
+    return 4096
+
+
+def test_sampling_params_from_request_with_no_guided_decoding_backend(
+        model_config, default_max_tokens):
+    # guided_decoding_backend is not present at request level
+    request = ChatCompletionRequest.model_validate({
+        'messages': [{
+            'role': 'user',
+            'content': 'Hello'
+        }],
+        'model':
+        MODEL_NAME,
+        'response_format': {
+            'type': 'json_object',
+        },
+    })
+
+    sampling_params = request.to_sampling_params(
+        default_max_tokens,
+        model_config.logits_processor_pattern,
+    )
+    # we do not expect any backend to be present and the default
+    # guided_decoding_backend at engine level will be used.
+    assert sampling_params.guided_decoding.backend is None
+
+
+@pytest.mark.parametrize("request_level_guided_decoding_backend,expected",
+                         [("xgrammar", "xgrammar"),
+                          ("lm-format-enforcer", "lm-format-enforcer"),
+                          ("outlines", "outlines")])
+def test_sampling_params_from_request_with_guided_decoding_backend(
+        request_level_guided_decoding_backend: str, expected: str,
+        model_config, default_max_tokens):
+
+    request = ChatCompletionRequest.model_validate({
+        'messages': [{
+            'role': 'user',
+            'content': 'Hello'
+        }],
+        'model':
+        MODEL_NAME,
+        'response_format': {
+            'type': 'json_object',
+        },
+        'guided_decoding_backend':
+        request_level_guided_decoding_backend,
+    })
+
+    sampling_params = request.to_sampling_params(
+        default_max_tokens,
+        model_config.logits_processor_pattern,
+    )
+    # backend correctly identified in resulting sampling_params
+    assert sampling_params.guided_decoding.backend == expected
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 7cbd9d7c..cbd5f6e5 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
             json_schema = self.response_format.json_schema
             assert json_schema is not None
             self.guided_json = json_schema.json_schema
-            if self.guided_decoding_backend is None:
-                self.guided_decoding_backend = "xgrammar"
 
         guided_decoding = GuidedDecodingParams.from_optional(
             json=self._get_guided_json_from_tool() or self.guided_json,
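
Usage note (not part of the patch): a minimal sketch of the behavior this change
restores, mirroring the first new test above. The request opts into JSON mode but
does not set a request-level guided_decoding_backend; with the removed fallback
gone, the backend stays unset in the resulting SamplingParams, so the engine-level
default (e.g. the server's --guided-decoding-backend setting) is respected instead
of being forced to "xgrammar". The literal 4096 stands in for the
default_max_tokens fixture and None for model_config.logits_processor_pattern;
both are illustrative assumptions, not values taken from the commit.

    from vllm.entrypoints.openai.protocol import ChatCompletionRequest

    # JSON mode requested, but no guided_decoding_backend at request level.
    request = ChatCompletionRequest.model_validate({
        'messages': [{'role': 'user', 'content': 'Hello'}],
        'model': 'Qwen/Qwen1.5-7B',
        'response_format': {'type': 'json_object'},
    })

    # Same call the test exercises: (default_max_tokens, logits_processor_pattern).
    sampling_params = request.to_sampling_params(4096, None)

    # Backend is left as None, so the engine-level default backend applies.
    assert sampling_params.guided_decoding.backend is None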