Default to generation_config from model (#12622)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 3b9c6c6947
commit 47512b3200
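This changes the default of --generation-config from None to "auto": by default vLLM now picks up the sampling defaults shipped in the model's own generation config, and the previous behaviour (ignore the model's generation config, use vLLM's built-in defaults) is selected with the new "vllm" value. A minimal sketch of the resulting semantics, assuming the Hugging Face convention of a generation_config.json file; resolve_generation_config is a hypothetical helper, not part of vLLM (the real code path is ModelConfig.try_get_generation_config, changed further down):

import json
from pathlib import Path


def resolve_generation_config(generation_config: str, model_path: str) -> dict:
    """Hypothetical illustration of the three accepted values."""
    if generation_config == "vllm":
        return {}  # ignore the model's generation config, use vLLM defaults
    # "auto" (the new default) reads from the model path; any other value is
    # treated as a folder containing a generation_config.json.
    folder = model_path if generation_config == "auto" else generation_config
    file = Path(folder) / "generation_config.json"
    return json.loads(file.read_text()) if file.exists() else {}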
@@ -20,7 +20,7 @@ NUM_CONCURRENT = 500
 TASK = "gsm8k"
 FILTER = "exact_match,strict-match"
 RTOL = 0.03
-EXPECTED_VALUE = 0.58
+EXPECTED_VALUE = 0.54
 DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
 MORE_ARGS_LIST = [
     [],  # Default
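The pinned gsm8k reference score for this correctness test drops from 0.58 to 0.54, presumably because the served model now applies its own generation-config defaults. A sketch of how the tolerance check implied by EXPECTED_VALUE/RTOL is typically written (the actual harness invocation is outside this hunk):

import numpy as np

RTOL = 0.03
EXPECTED_VALUE = 0.54


def check_accuracy(measured_value: float) -> None:
    # Relative-tolerance comparison against the pinned reference score.
    assert np.isclose(measured_value, EXPECTED_VALUE, rtol=RTOL), (
        f"gsm8k score {measured_value} not within {RTOL:.0%} of {EXPECTED_VALUE}")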
@@ -38,6 +38,7 @@ class MockModelConfig:
     diff_sampling_param: Optional[dict] = None
     allowed_local_media_path: str = ""
     encoder_config = None
+    generation_config: str = "auto"
 
     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}
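The test double gains the new attribute because the serving layers now read model_config.generation_config when logging their default sampling params. A minimal sketch of the mock, assuming it is a dataclass and showing only the fields visible in this hunk:

from dataclasses import dataclass
from typing import Optional


@dataclass
class MockModelConfig:
    # Only the fields visible in this diff; the real mock defines more.
    diff_sampling_param: Optional[dict] = None
    allowed_local_media_path: str = ""
    encoder_config = None
    generation_config: str = "auto"  # new field, mirroring ModelConfig

    def get_diff_sampling_param(self):
        return self.diff_sampling_param or {}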
@@ -289,7 +289,7 @@ def test_uses_mrope(model_id, uses_mrope):
 def test_generation_config_loading():
     model_id = "Qwen/Qwen2.5-1.5B-Instruct"
 
-    # When set generation_config to None, the default generation config
+    # When set generation_config to "vllm", the default generation config
     # will not be loaded.
     model_config = ModelConfig(model_id,
                                task="auto",
@@ -298,7 +298,7 @@ def test_generation_config_loading():
                                trust_remote_code=False,
                                seed=0,
                                dtype="float16",
-                               generation_config=None)
+                               generation_config="vllm")
     assert model_config.get_diff_sampling_param() == {}
 
     # When set generation_config to "auto", the default generation config
@@ -340,7 +340,7 @@ def test_generation_config_loading():
 
     assert model_config.get_diff_sampling_param() == override_result
 
-    # When generation_config is set to None and override_generation_config
+    # When generation_config is set to "vllm" and override_generation_config
     # is set, the override_generation_config should be used directly.
     model_config = ModelConfig(
         model_id,
@@ -350,7 +350,7 @@ def test_generation_config_loading():
         trust_remote_code=False,
         seed=0,
         dtype="float16",
-        generation_config=None,
+        generation_config="vllm",
         override_generation_config=override_generation_config)
 
     assert model_config.get_diff_sampling_param() == override_generation_config
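Taken together, the updated assertions cover three cases: "vllm" with no override yields {}, "auto" yields the model's generation-config values, and "vllm" plus override_generation_config yields exactly the override dict. A condensed, self-contained sketch of that behaviour (not vLLM code; merging overrides on top of the "auto" values is an assumption here):

from typing import Optional


def sampling_defaults(generation_config: str, model_defaults: dict,
                      override: Optional[dict] = None) -> dict:
    # "vllm" ignores the model's generation config entirely.
    base = {} if generation_config == "vllm" else dict(model_defaults)
    if override:
        base.update(override)  # assumption: overrides are merged on top
    return base


assert sampling_defaults("vllm", {"temperature": 0.7}) == {}
assert sampling_defaults("auto", {"temperature": 0.7}) == {"temperature": 0.7}
assert sampling_defaults("vllm", {"temperature": 0.7},
                         override={"temperature": 0.5}) == {"temperature": 0.5}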
@@ -255,7 +255,7 @@ class ModelConfig:
         override_neuron_config: Optional[dict[str, Any]] = None,
         override_pooler_config: Optional["PoolerConfig"] = None,
         logits_processor_pattern: Optional[str] = None,
-        generation_config: Optional[str] = None,
+        generation_config: str = "auto",
         enable_sleep_mode: bool = False,
         override_generation_config: Optional[dict[str, Any]] = None,
         model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
@@ -951,7 +951,7 @@ class ModelConfig:
         return self.multimodal_config
 
     def try_get_generation_config(self) -> dict[str, Any]:
-        if self.generation_config is None or self.generation_config == "auto":
+        if self.generation_config in ("auto", "vllm"):
             config = try_get_generation_config(
                 self.hf_config_path or self.model,
                 trust_remote_code=self.trust_remote_code,
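Note that the new condition still loads the HF generation config for "vllm" as well as "auto"; the "vllm" case is only short-circuited one level up, in get_diff_sampling_param (next hunk). A rough two-stage sketch of that interplay (illustrative only; the folder-path branch and the actual loader are outside this hunk):

def try_get(generation_config: str, hf_defaults: dict) -> dict:
    # Stage 1: "auto" and "vllm" both read the model's HF generation config.
    return dict(hf_defaults) if generation_config in ("auto", "vllm") else {}


def diff_sampling_param(generation_config: str, hf_defaults: dict) -> dict:
    # Stage 2: "vllm" discards the loaded values, so vLLM defaults apply.
    if generation_config == "vllm":
        return {}
    return try_get(generation_config, hf_defaults)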
@@ -971,17 +971,14 @@ class ModelConfig:
     def get_diff_sampling_param(self) -> dict[str, Any]:
         """
         This method returns a dictionary containing the parameters
-        that differ from the default sampling parameters, but only
-        if `generation_config` is set. If `generation_config` is not
-        set, an empty dictionary is returned.
+        that differ from the default sampling parameters. If
+        `generation_config` is `"vllm"`, an empty dictionary is returned.
 
         Returns:
             dict[str, Any]: A dictionary with the differing sampling
-                parameters if `generation_config` is set, otherwise an
-                empty dictionary.
+                parameters, if `generation_config` is `"vllm"` an empty dictionary.
         """
-        if self.generation_config is None:
-            # When generation_config is not set
+        if self.generation_config == "vllm":
             config = {}
         else:
             config = self.try_get_generation_config()
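The docstring's "parameters that differ from the default sampling parameters" refers to values the model's generation config sets away from vLLM's own sampling defaults. A hedged illustration of that filtering idea (the default values and the filtering code below are assumptions, not the actual implementation):

# Assumed vLLM-side sampling defaults, for illustration only.
SAMPLING_DEFAULTS = {"temperature": 1.0, "top_p": 1.0, "top_k": -1}


def diff_from_defaults(model_generation_config: dict) -> dict:
    return {
        key: value for key, value in model_generation_config.items()
        if key in SAMPLING_DEFAULTS and value != SAMPLING_DEFAULTS[key]
    }


print(diff_from_defaults({"temperature": 0.7, "top_p": 1.0, "max_length": 2048}))
# -> {'temperature': 0.7}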
@@ -207,7 +207,7 @@ class EngineArgs:
 
     kv_transfer_config: Optional[KVTransferConfig] = None
 
-    generation_config: Optional[str] = None
+    generation_config: Optional[str] = "auto"
     override_generation_config: Optional[Dict[str, Any]] = None
     enable_sleep_mode: bool = False
     model_impl: str = "auto"
@@ -1034,13 +1034,13 @@ class EngineArgs:
         parser.add_argument(
             "--generation-config",
             type=nullable_str,
-            default=None,
+            default="auto",
             help="The folder path to the generation config. "
-            "Defaults to None, no generation config is loaded, vLLM defaults "
-            "will be used. If set to 'auto', the generation config will be "
-            "loaded from model path. If set to a folder path, the generation "
-            "config will be loaded from the specified folder path. If "
-            "`max_new_tokens` is specified in generation config, then "
+            "Defaults to 'auto', the generation config will be loaded from "
+            "model path. If set to 'vllm', no generation config is loaded, "
+            "vLLM defaults will be used. If set to a folder path, the "
+            "generation config will be loaded from the specified folder path. "
+            "If `max_new_tokens` is specified in generation config, then "
             "it sets a server-wide limit on the number of output tokens "
             "for all requests.")
 
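On the CLI this means the previous behaviour is now spelled --generation-config vllm, while omitting the flag picks up the model's generation config. The help text also keeps the note that a max_new_tokens entry acts as a server-wide cap on output tokens; a hedged sketch of what such a cap amounts to (the enforcement code is not part of this diff):

from typing import Optional


def effective_max_tokens(requested: Optional[int],
                         server_limit: Optional[int]) -> Optional[int]:
    """Illustrative only: clamp a per-request max_tokens to a server-wide
    limit derived from the generation config's max_new_tokens."""
    if server_limit is None:
        return requested
    if requested is None:
        return server_limit
    return min(requested, server_limit)


assert effective_max_tokens(512, 256) == 256
assert effective_max_tokens(None, 256) == 256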
@@ -109,8 +109,10 @@ class OpenAIServingChat(OpenAIServing):
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
         if self.default_sampling_params:
-            logger.info("Overwriting default chat sampling param with: %s",
-                        self.default_sampling_params)
+            source = self.model_config.generation_config
+            source = "model" if source == "auto" else source
+            logger.info("Using default chat sampling params from %s: %s",
+                        source, self.default_sampling_params)
 
     async def create_chat_completion(
         self,
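The completion handler in the next hunk gets the same treatment. The net effect on startup logging, with placeholder values standing in for what ModelConfig would provide:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("sketch")

# Placeholders: generation_config comes from ModelConfig, and
# default_sampling_params from get_diff_sampling_param().
generation_config = "auto"
default_sampling_params = {"temperature": 0.7}

if default_sampling_params:
    source = generation_config
    source = "model" if source == "auto" else source
    logger.info("Using default chat sampling params from %s: %s",
                source, default_sampling_params)
# Logs roughly: "Using default chat sampling params from model: {'temperature': 0.7}"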
@@ -55,9 +55,10 @@ class OpenAIServingCompletion(OpenAIServing):
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
         if self.default_sampling_params:
-            logger.info(
-                "Overwriting default completion sampling param with: %s",
-                self.default_sampling_params)
+            source = self.model_config.generation_config
+            source = "model" if source == "auto" else source
+            logger.info("Using default completion sampling params from %s: %s",
+                        source, self.default_sampling_params)
 
     async def create_completion(
         self,