[Bugfix][Frontend] respect provided default guided decoding backend (#15476)
Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>
parent d55244df31
commit 98d01d3ce2
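Intended behavior, sketched from the tests added below (illustrative only, not part of the diff; the 4096 max-tokens value and the None logits-processor pattern passed to to_sampling_params are assumptions): when a request does not set guided_decoding_backend, the backend in the resulting guided decoding params stays None, so the engine-level default backend is used instead of being forced to "xgrammar".

    from vllm.entrypoints.openai.protocol import ChatCompletionRequest

    # Request with a JSON response format but no request-level backend.
    request = ChatCompletionRequest.model_validate({
        'messages': [{'role': 'user', 'content': 'Hello'}],
        'model': 'Qwen/Qwen1.5-7B',
        'response_format': {'type': 'json_object'},
    })

    # Positional args: default max_tokens and logits-processor pattern.
    sampling_params = request.to_sampling_params(4096, None)

    # Previously this was forced to "xgrammar"; now it is left unset so the
    # engine-level default guided decoding backend applies.
    assert sampling_params.guided_decoding.backend is None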
@@ -1,7 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 """Tests for the SamplingParams class.
 """
+
+import pytest
+
 from vllm import SamplingParams
+from vllm.config import ModelConfig
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+
+MODEL_NAME = "Qwen/Qwen1.5-7B"
 
 
 def test_max_tokens_none():
@@ -9,6 +16,74 @@ def test_max_tokens_none():
     SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None)
 
 
 if __name__ == "__main__":
     import pytest
     pytest.main([__file__])
+
+
+@pytest.fixture(scope="module")
+def model_config():
+    return ModelConfig(
+        MODEL_NAME,
+        task="auto",
+        tokenizer=MODEL_NAME,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+
+
+@pytest.fixture(scope="module")
+def default_max_tokens():
+    return 4096
+
+
+def test_sampling_params_from_request_with_no_guided_decoding_backend(
+        model_config, default_max_tokens):
+    # guided_decoding_backend is not present at request level
+    request = ChatCompletionRequest.model_validate({
+        'messages': [{
+            'role': 'user',
+            'content': 'Hello'
+        }],
+        'model':
+        MODEL_NAME,
+        'response_format': {
+            'type': 'json_object',
+        },
+    })
+
+    sampling_params = request.to_sampling_params(
+        default_max_tokens,
+        model_config.logits_processor_pattern,
+    )
+    # we do not expect any backend to be present and the default
+    # guided_decoding_backend at engine level will be used.
+    assert sampling_params.guided_decoding.backend is None
+
+
+@pytest.mark.parametrize("request_level_guided_decoding_backend,expected",
+                         [("xgrammar", "xgrammar"),
+                          ("lm-format-enforcer", "lm-format-enforcer"),
+                          ("outlines", "outlines")])
+def test_sampling_params_from_request_with_guided_decoding_backend(
+        request_level_guided_decoding_backend: str, expected: str,
+        model_config, default_max_tokens):
+
+    request = ChatCompletionRequest.model_validate({
+        'messages': [{
+            'role': 'user',
+            'content': 'Hello'
+        }],
+        'model':
+        MODEL_NAME,
+        'response_format': {
+            'type': 'json_object',
+        },
+        'guided_decoding_backend':
+        request_level_guided_decoding_backend,
+    })
+
+    sampling_params = request.to_sampling_params(
+        default_max_tokens,
+        model_config.logits_processor_pattern,
+    )
+    # backend correctly identified in resulting sampling_params
+    assert sampling_params.guided_decoding.backend == expected
@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
                 json_schema = self.response_format.json_schema
                 assert json_schema is not None
                 self.guided_json = json_schema.json_schema
-                if self.guided_decoding_backend is None:
-                    self.guided_decoding_backend = "xgrammar"
 
         guided_decoding = GuidedDecodingParams.from_optional(
             json=self._get_guided_json_from_tool() or self.guided_json,