Default to generation_config from model (#12622)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-03-08 07:46:15 +01:00 committed by GitHub
parent 3b9c6c6947
commit 47512b3200
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 27 additions and 26 deletions

View File

@ -20,7 +20,7 @@ NUM_CONCURRENT = 500
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUE = 0.58
EXPECTED_VALUE = 0.54
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
MORE_ARGS_LIST = [
[], # Default

View File

@ -38,6 +38,7 @@ class MockModelConfig:
diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = ""
encoder_config = None
generation_config: str = "auto"
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}

View File

@ -289,7 +289,7 @@ def test_uses_mrope(model_id, uses_mrope):
def test_generation_config_loading():
model_id = "Qwen/Qwen2.5-1.5B-Instruct"
# When set generation_config to None, the default generation config
# When set generation_config to "vllm", the default generation config
# will not be loaded.
model_config = ModelConfig(model_id,
task="auto",
@ -298,7 +298,7 @@ def test_generation_config_loading():
trust_remote_code=False,
seed=0,
dtype="float16",
generation_config=None)
generation_config="vllm")
assert model_config.get_diff_sampling_param() == {}
# When set generation_config to "auto", the default generation config
@ -340,7 +340,7 @@ def test_generation_config_loading():
assert model_config.get_diff_sampling_param() == override_result
# When generation_config is set to None and override_generation_config
# When generation_config is set to "vllm" and override_generation_config
# is set, the override_generation_config should be used directly.
model_config = ModelConfig(
model_id,
@ -350,7 +350,7 @@ def test_generation_config_loading():
trust_remote_code=False,
seed=0,
dtype="float16",
generation_config=None,
generation_config="vllm",
override_generation_config=override_generation_config)
assert model_config.get_diff_sampling_param() == override_generation_config

View File

@ -255,7 +255,7 @@ class ModelConfig:
override_neuron_config: Optional[dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None,
generation_config: Optional[str] = None,
generation_config: str = "auto",
enable_sleep_mode: bool = False,
override_generation_config: Optional[dict[str, Any]] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
@ -951,7 +951,7 @@ class ModelConfig:
return self.multimodal_config
def try_get_generation_config(self) -> dict[str, Any]:
if self.generation_config is None or self.generation_config == "auto":
if self.generation_config in ("auto", "vllm"):
config = try_get_generation_config(
self.hf_config_path or self.model,
trust_remote_code=self.trust_remote_code,
@ -971,17 +971,14 @@ class ModelConfig:
def get_diff_sampling_param(self) -> dict[str, Any]:
"""
This method returns a dictionary containing the parameters
that differ from the default sampling parameters, but only
if `generation_config` is set. If `generation_config` is not
set, an empty dictionary is returned.
that differ from the default sampling parameters. If
`generation_config` is `"vllm"`, an empty dictionary is returned.
Returns:
dict[str, Any]: A dictionary with the differing sampling
parameters if `generation_config` is set, otherwise an
empty dictionary.
parameters, if `generation_config` is `"vllm"` an empty dictionary.
"""
if self.generation_config is None:
# When generation_config is not set
if self.generation_config == "vllm":
config = {}
else:
config = self.try_get_generation_config()

View File

@ -207,7 +207,7 @@ class EngineArgs:
kv_transfer_config: Optional[KVTransferConfig] = None
generation_config: Optional[str] = None
generation_config: Optional[str] = "auto"
override_generation_config: Optional[Dict[str, Any]] = None
enable_sleep_mode: bool = False
model_impl: str = "auto"
@ -1034,13 +1034,13 @@ class EngineArgs:
parser.add_argument(
"--generation-config",
type=nullable_str,
default=None,
default="auto",
help="The folder path to the generation config. "
"Defaults to None, no generation config is loaded, vLLM defaults "
"will be used. If set to 'auto', the generation config will be "
"loaded from model path. If set to a folder path, the generation "
"config will be loaded from the specified folder path. If "
"`max_new_tokens` is specified in generation config, then "
"Defaults to 'auto', the generation config will be loaded from "
"model path. If set to 'vllm', no generation config is loaded, "
"vLLM defaults will be used. If set to a folder path, the "
"generation config will be loaded from the specified folder path. "
"If `max_new_tokens` is specified in generation config, then "
"it sets a server-wide limit on the number of output tokens "
"for all requests.")

View File

@ -109,8 +109,10 @@ class OpenAIServingChat(OpenAIServing):
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
logger.info("Overwriting default chat sampling param with: %s",
self.default_sampling_params)
source = self.model_config.generation_config
source = "model" if source == "auto" else source
logger.info("Using default chat sampling params from %s: %s",
source, self.default_sampling_params)
async def create_chat_completion(
self,

View File

@ -55,9 +55,10 @@ class OpenAIServingCompletion(OpenAIServing):
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
logger.info(
"Overwriting default completion sampling param with: %s",
self.default_sampling_params)
source = self.model_config.generation_config
source = "model" if source == "auto" else source
logger.info("Using default completion sampling params from %s: %s",
source, self.default_sampling_params)
async def create_completion(
self,