diff --git a/tests/test_config.py b/tests/test_config.py
index 4518adfc..ec366b93 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -281,3 +281,73 @@ def test_uses_mrope(model_id, uses_mrope):
     )
 
     assert config.uses_mrope == uses_mrope
+
+
+def test_generation_config_loading():
+    model_id = "Qwen/Qwen2.5-1.5B-Instruct"
+
+    # When set generation_config to None, the default generation config
+    # will not be loaded.
+    model_config = ModelConfig(model_id,
+                               task="auto",
+                               tokenizer=model_id,
+                               tokenizer_mode="auto",
+                               trust_remote_code=False,
+                               seed=0,
+                               dtype="float16",
+                               generation_config=None)
+    assert model_config.get_diff_sampling_param() == {}
+
+    # When set generation_config to "auto", the default generation config
+    # should be loaded.
+    model_config = ModelConfig(model_id,
+                               task="auto",
+                               tokenizer=model_id,
+                               tokenizer_mode="auto",
+                               trust_remote_code=False,
+                               seed=0,
+                               dtype="float16",
+                               generation_config="auto")
+
+    correct_generation_config = {
+        "repetition_penalty": 1.1,
+        "temperature": 0.7,
+        "top_p": 0.8,
+        "top_k": 20,
+    }
+
+    assert model_config.get_diff_sampling_param() == correct_generation_config
+
+    # The generation config could be overridden by the user.
+    override_generation_config = {"temperature": 0.5, "top_k": 5}
+
+    model_config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        generation_config="auto",
+        override_generation_config=override_generation_config)
+
+    override_result = correct_generation_config.copy()
+    override_result.update(override_generation_config)
+
+    assert model_config.get_diff_sampling_param() == override_result
+
+    # When generation_config is set to None and override_generation_config
+    # is set, the override_generation_config should be used directly.
+    model_config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        generation_config=None,
+        override_generation_config=override_generation_config)
+
+    assert model_config.get_diff_sampling_param() == override_generation_config
diff --git a/vllm/config.py b/vllm/config.py
index d7c9311a..58464eae 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -165,6 +165,8 @@ class ModelConfig:
             `logits_processors` extra completion argument. Defaults to None,
             which allows no processors.
         generation_config: Configuration parameter file for generation.
+        override_generation_config: Override the generation config with the
+            given config.
     """
 
     def compute_hash(self) -> str:
@@ -225,6 +227,7 @@ class ModelConfig:
         logits_processor_pattern: Optional[str] = None,
         generation_config: Optional[str] = None,
         enable_sleep_mode: bool = False,
+        override_generation_config: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -368,6 +371,7 @@ class ModelConfig:
         self.logits_processor_pattern = logits_processor_pattern
 
         self.generation_config = generation_config
+        self.override_generation_config = override_generation_config or {}
 
         self._verify_quantization()
         self._verify_cuda_graph()
@@ -904,8 +908,13 @@ class ModelConfig:
         """
         if self.generation_config is None:
             # When generation_config is not set
-            return {}
-        config = self.try_get_generation_config()
+            config = {}
+        else:
+            config = self.try_get_generation_config()
+
+        # Overriding with given generation config
+        config.update(self.override_generation_config)
+
         available_params = [
             "repetition_penalty",
             "temperature",
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index ba96484e..1f203b6e 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -195,6 +195,7 @@ class EngineArgs:
     kv_transfer_config: Optional[KVTransferConfig] = None
 
     generation_config: Optional[str] = None
+    override_generation_config: Optional[Dict[str, Any]] = None
     enable_sleep_mode: bool = False
 
     calculate_kv_scales: Optional[bool] = None
@@ -936,12 +937,23 @@ class EngineArgs:
             type=nullable_str,
             default=None,
             help="The folder path to the generation config. "
-            "Defaults to None, will use the default generation config in vLLM. "
-            "If set to 'auto', the generation config will be automatically "
-            "loaded from model. If set to a folder path, the generation config "
-            "will be loaded from the specified folder path. If "
-            "`max_new_tokens` is specified, then it sets a server-wide limit "
-            "on the number of output tokens for all requests.")
+            "Defaults to None, no generation config is loaded, vLLM defaults "
+            "will be used. If set to 'auto', the generation config will be "
+            "loaded from model path. If set to a folder path, the generation "
+            "config will be loaded from the specified folder path. If "
+            "`max_new_tokens` is specified in generation config, then "
+            "it sets a server-wide limit on the number of output tokens "
+            "for all requests.")
+
+        parser.add_argument(
+            "--override-generation-config",
+            type=json.loads,
+            default=None,
+            help="Overrides or sets generation config in JSON format. "
+            "e.g. ``{\"temperature\": 0.5}``. If used with "
+            "--generation-config=auto, the override parameters will be merged "
+            "with the default config from the model. If generation-config is "
+            "None, only the override parameters are used.")
 
         parser.add_argument("--enable-sleep-mode",
                             action="store_true",
@@ -1002,6 +1014,7 @@ class EngineArgs:
             override_pooler_config=self.override_pooler_config,
             logits_processor_pattern=self.logits_processor_pattern,
             generation_config=self.generation_config,
+            override_generation_config=self.override_generation_config,
             enable_sleep_mode=self.enable_sleep_mode,
         )
 