[Frontend] Support override generation config in args (#12409)

Signed-off-by: liuyanyi <wolfsonliu@163.com>
This commit is contained in:
Yanyi Liu 2025-01-29 17:41:01 +08:00 committed by GitHub
parent d93bf4da85
commit ff7424f491
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 100 additions and 8 deletions

View File

@ -281,3 +281,73 @@ def test_uses_mrope(model_id, uses_mrope):
) )
assert config.uses_mrope == uses_mrope assert config.uses_mrope == uses_mrope
def test_generation_config_loading():
model_id = "Qwen/Qwen2.5-1.5B-Instruct"
# When set generation_config to None, the default generation config
# will not be loaded.
model_config = ModelConfig(model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
generation_config=None)
assert model_config.get_diff_sampling_param() == {}
# When set generation_config to "auto", the default generation config
# should be loaded.
model_config = ModelConfig(model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
generation_config="auto")
correct_generation_config = {
"repetition_penalty": 1.1,
"temperature": 0.7,
"top_p": 0.8,
"top_k": 20,
}
assert model_config.get_diff_sampling_param() == correct_generation_config
# The generation config could be overridden by the user.
override_generation_config = {"temperature": 0.5, "top_k": 5}
model_config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
generation_config="auto",
override_generation_config=override_generation_config)
override_result = correct_generation_config.copy()
override_result.update(override_generation_config)
assert model_config.get_diff_sampling_param() == override_result
# When generation_config is set to None and override_generation_config
# is set, the override_generation_config should be used directly.
model_config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
generation_config=None,
override_generation_config=override_generation_config)
assert model_config.get_diff_sampling_param() == override_generation_config

View File

@ -165,6 +165,8 @@ class ModelConfig:
`logits_processors` extra completion argument. Defaults to None, `logits_processors` extra completion argument. Defaults to None,
which allows no processors. which allows no processors.
generation_config: Configuration parameter file for generation. generation_config: Configuration parameter file for generation.
override_generation_config: Override the generation config with the
given config.
""" """
def compute_hash(self) -> str: def compute_hash(self) -> str:
@ -225,6 +227,7 @@ class ModelConfig:
logits_processor_pattern: Optional[str] = None, logits_processor_pattern: Optional[str] = None,
generation_config: Optional[str] = None, generation_config: Optional[str] = None,
enable_sleep_mode: bool = False, enable_sleep_mode: bool = False,
override_generation_config: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
@ -368,6 +371,7 @@ class ModelConfig:
self.logits_processor_pattern = logits_processor_pattern self.logits_processor_pattern = logits_processor_pattern
self.generation_config = generation_config self.generation_config = generation_config
self.override_generation_config = override_generation_config or {}
self._verify_quantization() self._verify_quantization()
self._verify_cuda_graph() self._verify_cuda_graph()
@ -904,8 +908,13 @@ class ModelConfig:
""" """
if self.generation_config is None: if self.generation_config is None:
# When generation_config is not set # When generation_config is not set
return {} config = {}
else:
config = self.try_get_generation_config() config = self.try_get_generation_config()
# Overriding with given generation config
config.update(self.override_generation_config)
available_params = [ available_params = [
"repetition_penalty", "repetition_penalty",
"temperature", "temperature",

View File

@ -195,6 +195,7 @@ class EngineArgs:
kv_transfer_config: Optional[KVTransferConfig] = None kv_transfer_config: Optional[KVTransferConfig] = None
generation_config: Optional[str] = None generation_config: Optional[str] = None
override_generation_config: Optional[Dict[str, Any]] = None
enable_sleep_mode: bool = False enable_sleep_mode: bool = False
calculate_kv_scales: Optional[bool] = None calculate_kv_scales: Optional[bool] = None
@ -936,12 +937,23 @@ class EngineArgs:
type=nullable_str, type=nullable_str,
default=None, default=None,
help="The folder path to the generation config. " help="The folder path to the generation config. "
"Defaults to None, will use the default generation config in vLLM. " "Defaults to None, no generation config is loaded, vLLM defaults "
"If set to 'auto', the generation config will be automatically " "will be used. If set to 'auto', the generation config will be "
"loaded from model. If set to a folder path, the generation config " "loaded from model path. If set to a folder path, the generation "
"will be loaded from the specified folder path. If " "config will be loaded from the specified folder path. If "
"`max_new_tokens` is specified, then it sets a server-wide limit " "`max_new_tokens` is specified in generation config, then "
"on the number of output tokens for all requests.") "it sets a server-wide limit on the number of output tokens "
"for all requests.")
parser.add_argument(
"--override-generation-config",
type=json.loads,
default=None,
help="Overrides or sets generation config in JSON format. "
"e.g. ``{\"temperature\": 0.5}``. If used with "
"--generation-config=auto, the override parameters will be merged "
"with the default config from the model. If generation-config is "
"None, only the override parameters are used.")
parser.add_argument("--enable-sleep-mode", parser.add_argument("--enable-sleep-mode",
action="store_true", action="store_true",
@ -1002,6 +1014,7 @@ class EngineArgs:
override_pooler_config=self.override_pooler_config, override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern, logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config, generation_config=self.generation_config,
override_generation_config=self.override_generation_config,
enable_sleep_mode=self.enable_sleep_mode, enable_sleep_mode=self.enable_sleep_mode,
) )