[Bugfix][Frontend] respect provided default guided decoding backend (#15476)
Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>
commit 98d01d3ce2
parent d55244df31
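
Context for the change: ChatCompletionRequest.to_sampling_params previously hard-coded the guided-decoding backend to "xgrammar" whenever a request asked for json_object output without naming a backend, which silently overrode the engine-level default (e.g. one supplied via the server's --guided-decoding-backend flag). After this patch the request-level field stays None unless the client sets it, so the engine default is respected. A minimal sketch of the intended resolution order (resolve_backend is a hypothetical helper for illustration only, not part of vLLM):

from typing import Optional


def resolve_backend(request_backend: Optional[str],
                    engine_default: str) -> str:
    # After this fix: a backend named in the request wins; otherwise
    # the engine-level default applies. Before the fix, a json_object
    # request with no backend was forced to "xgrammar", shadowing
    # engine_default.
    return request_backend if request_backend is not None else engine_default


assert resolve_backend(None, "outlines") == "outlines"  # engine default used
assert resolve_backend("lm-format-enforcer", "outlines") == "lm-format-enforcer"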
tests/test_sampling_params.py
@@ -1,7 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 """Tests for the SamplingParams class.
 """
+
+import pytest
+
 from vllm import SamplingParams
+from vllm.config import ModelConfig
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+
+MODEL_NAME = "Qwen/Qwen1.5-7B"
 
 
 def test_max_tokens_none():
@@ -9,6 +16,74 @@ def test_max_tokens_none():
     SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None)
 
 
-if __name__ == "__main__":
-    import pytest
-    pytest.main([__file__])
+@pytest.fixture(scope="module")
+def model_config():
+    return ModelConfig(
+        MODEL_NAME,
+        task="auto",
+        tokenizer=MODEL_NAME,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+
+
+@pytest.fixture(scope="module")
+def default_max_tokens():
+    return 4096
+
+
+def test_sampling_params_from_request_with_no_guided_decoding_backend(
+        model_config, default_max_tokens):
+    # guided_decoding_backend is not present at request level
+    request = ChatCompletionRequest.model_validate({
+        'messages': [{
+            'role': 'user',
+            'content': 'Hello'
+        }],
+        'model':
+        MODEL_NAME,
+        'response_format': {
+            'type': 'json_object',
+        },
+    })
+
+    sampling_params = request.to_sampling_params(
+        default_max_tokens,
+        model_config.logits_processor_pattern,
+    )
+    # we do not expect any backend to be present and the default
+    # guided_decoding_backend at engine level will be used.
+    assert sampling_params.guided_decoding.backend is None
+
+
+@pytest.mark.parametrize("request_level_guided_decoding_backend,expected",
+                         [("xgrammar", "xgrammar"),
+                          ("lm-format-enforcer", "lm-format-enforcer"),
+                          ("outlines", "outlines")])
+def test_sampling_params_from_request_with_guided_decoding_backend(
+        request_level_guided_decoding_backend: str, expected: str,
+        model_config, default_max_tokens):
+
+    request = ChatCompletionRequest.model_validate({
+        'messages': [{
+            'role': 'user',
+            'content': 'Hello'
+        }],
+        'model':
+        MODEL_NAME,
+        'response_format': {
+            'type': 'json_object',
+        },
+        'guided_decoding_backend':
+        request_level_guided_decoding_backend,
+    })
+
+    sampling_params = request.to_sampling_params(
+        default_max_tokens,
+        model_config.logits_processor_pattern,
+    )
+    # backend correctly identified in resulting sampling_params
+    assert sampling_params.guided_decoding.backend == expected
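
The new tests can be run in isolation with, for example, pytest tests/test_sampling_params.py -k guided_decoding_backend (the file path is inferred from the module docstring above).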

vllm/entrypoints/openai/protocol.py
@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
             json_schema = self.response_format.json_schema
             assert json_schema is not None
             self.guided_json = json_schema.json_schema
-            if self.guided_decoding_backend is None:
-                self.guided_decoding_backend = "xgrammar"
 
         guided_decoding = GuidedDecodingParams.from_optional(
             json=self._get_guided_json_from_tool() or self.guided_json,
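
With the hard-coded fallback removed, a client can still pick a backend per request, while requests that stay silent defer to whatever the server was started with. A hedged client-side sketch using the openai Python client against a vLLM server (base_url, api_key, and model are placeholders; passing guided_decoding_backend through extra_body is a vLLM-specific extension):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# No backend named: the engine-level default now applies (the first
# new test above asserts that backend is None at this stage).
resp = client.chat.completions.create(
    model="Qwen/Qwen1.5-7B",
    messages=[{"role": "user", "content": "Hello"}],
    response_format={"type": "json_object"},
)

# Backend named per request: still honored, as the parametrized test asserts.
resp = client.chat.completions.create(
    model="Qwen/Qwen1.5-7B",
    messages=[{"role": "user", "content": "Hello"}],
    response_format={"type": "json_object"},
    extra_body={"guided_decoding_backend": "outlines"},
)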