[Bugfix][Frontend] respect provided default guided decoding backend (#15476)

Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>
Author: Guillaume Calmettes, 2025-04-09 14:11:10 +02:00 (committed by GitHub)
parent d55244df31
commit 98d01d3ce2
2 changed files with 78 additions and 5 deletions
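
The gist of the fix: ChatCompletionRequest.to_sampling_params previously rewrote a missing request-level guided_decoding_backend to a hard-coded "xgrammar", which shadowed whatever default backend the engine was configured with. After this change the request-level value is passed through as-is (possibly None), so the engine-level default applies. A minimal standalone sketch of that precedence, using a hypothetical helper rather than vLLM's actual internals:

from typing import Optional


def resolve_guided_backend(request_backend: Optional[str],
                           engine_default: str) -> str:
    # Request-level backend wins when provided; otherwise fall back to the
    # engine-level default instead of a hard-coded "xgrammar".
    return request_backend if request_backend is not None else engine_default


# Before the fix, the first case would silently resolve to "xgrammar".
assert resolve_guided_backend(None, "lm-format-enforcer") == "lm-format-enforcer"
assert resolve_guided_backend("outlines", "lm-format-enforcer") == "outlines"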

tests/test_sampling_params.py

@@ -1,7 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 """Tests for the SamplingParams class.
 """
+import pytest
 
 from vllm import SamplingParams
+from vllm.config import ModelConfig
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+
+MODEL_NAME = "Qwen/Qwen1.5-7B"
 
 
 def test_max_tokens_none():
@@ -9,6 +16,74 @@ def test_max_tokens_none():
     SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None)
 
 
-if __name__ == "__main__":
-    import pytest
-    pytest.main([__file__])
+@pytest.fixture(scope="module")
+def model_config():
+    return ModelConfig(
+        MODEL_NAME,
+        task="auto",
+        tokenizer=MODEL_NAME,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+
+
+@pytest.fixture(scope="module")
+def default_max_tokens():
+    return 4096
+
+
+def test_sampling_params_from_request_with_no_guided_decoding_backend(
+        model_config, default_max_tokens):
+    # guided_decoding_backend is not present at request level
+    request = ChatCompletionRequest.model_validate({
+        'messages': [{
+            'role': 'user',
+            'content': 'Hello'
+        }],
+        'model': MODEL_NAME,
+        'response_format': {
+            'type': 'json_object',
+        },
+    })
+    sampling_params = request.to_sampling_params(
+        default_max_tokens,
+        model_config.logits_processor_pattern,
+    )
+    # we do not expect any backend to be present and the default
+    # guided_decoding_backend at engine level will be used.
+    assert sampling_params.guided_decoding.backend is None
+
+
+@pytest.mark.parametrize("request_level_guided_decoding_backend,expected",
+                         [("xgrammar", "xgrammar"),
+                          ("lm-format-enforcer", "lm-format-enforcer"),
+                          ("outlines", "outlines")])
+def test_sampling_params_from_request_with_guided_decoding_backend(
+        request_level_guided_decoding_backend: str, expected: str,
+        model_config, default_max_tokens):
+    request = ChatCompletionRequest.model_validate({
+        'messages': [{
+            'role': 'user',
+            'content': 'Hello'
+        }],
+        'model': MODEL_NAME,
+        'response_format': {
+            'type': 'json_object',
+        },
+        'guided_decoding_backend': request_level_guided_decoding_backend,
+    })
+    sampling_params = request.to_sampling_params(
+        default_max_tokens,
+        model_config.logits_processor_pattern,
+    )
+    # backend correctly identified in resulting sampling_params
+    assert sampling_params.guided_decoding.backend == expected

vllm/entrypoints/openai/protocol.py

@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
             json_schema = self.response_format.json_schema
             assert json_schema is not None
             self.guided_json = json_schema.json_schema
-        if self.guided_decoding_backend is None:
-            self.guided_decoding_backend = "xgrammar"
 
         guided_decoding = GuidedDecodingParams.from_optional(
             json=self._get_guided_json_from_tool() or self.guided_json,
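
With the two deleted lines gone, a request that sets response_format but omits guided_decoding_backend reaches GuidedDecodingParams.from_optional with the backend left as None, so the backend configured at the engine level (for example via the server's --guided-decoding-backend argument) is used downstream. A hedged client-side sketch of the behavior this enables, assuming a local vLLM OpenAI-compatible server is already running with a non-default backend; the flag name, port, and model are assumptions, not part of this commit:

from openai import OpenAI

# Point the standard OpenAI client at the local vLLM server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# No 'guided_decoding_backend' in the request: with this fix the engine's
# configured default backend handles the json_object constraint instead of
# an unconditionally forced "xgrammar".
completion = client.chat.completions.create(
    model="Qwen/Qwen1.5-7B",
    messages=[{"role": "user", "content": "Reply with a small JSON object."}],
    response_format={"type": "json_object"},
)
print(completion.choices[0].message.content)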