[Feature] specify model in config.yaml (#14855)

Signed-off-by: weizeng <weizeng@roblox.com>
Wei Zeng 2025-03-21 00:26:03 -07:00, committed by GitHub
parent da6ea29f7a
commit 0fa3970deb
7 changed files with 102 additions and 30 deletions


@@ -184,6 +184,7 @@ For example:

 ```yaml
 # config.yaml
+model: meta-llama/Llama-3.1-8B-Instruct
 host: "127.0.0.1"
 port: 6379
 uvicorn-log-level: "info"
@@ -192,12 +193,13 @@ uvicorn-log-level: "info"
 To use the above config file:

 ```bash
-vllm serve SOME_MODEL --config config.yaml
+vllm serve --config config.yaml
 ```

 :::{note}
 If an argument is supplied both via the command line and the config file, the value from the command line takes precedence.
 The order of priorities is `command line > config file values > defaults`.
+e.g. with `vllm serve SOME_MODEL --config config.yaml`, `SOME_MODEL` takes precedence over the `model` value in the config file.
 :::

 ## API Reference

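To make the precedence rule concrete, here is a minimal sketch (an editor's illustration, not part of this PR; the argument set and the config contents are assumed) using the same `FlexibleArgumentParser` that backs `vllm serve`:

```python
# Minimal sketch of `command line > config file values > defaults`.
# Assumes a config.yaml containing only `port: 6379`.
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser()
parser.add_argument('serve')
parser.add_argument('model_tag', nargs='?')
parser.add_argument('--model', type=str)
parser.add_argument('--config', type=str)
parser.add_argument('--port', type=int, default=8000)

# Config file alone: port comes from config.yaml, not the default.
args = parser.parse_args(['serve', 'mymodel', '--config', 'config.yaml'])
assert args.port == 6379

# Config file plus an explicit flag: the command line wins.
args = parser.parse_args(
    ['serve', 'mymodel', '--config', 'config.yaml', '--port', '1234'])
assert args.port == 1234
```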

@@ -0,0 +1,7 @@
+# Same as test_config.yaml but with model specified
+model: config-model
+port: 12312
+served_model_name: mymodel
+tensor_parallel_size: 2
+trust_remote_code: true
+multi_step_stream_outputs: false
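For orientation, this is roughly how such a file is expanded into CLI-style flags before parsing (an editor's sketch of the idea only; the real logic lives in `FlexibleArgumentParser._load_config_file` below and additionally handles `StoreBoolean` arguments):

```python
# Sketch: turn config keys into flag/value pairs; true booleans become
# bare flags. Not vLLM's exact code.
import yaml

with open('test_config_with_model.yaml') as f:
    config = yaml.safe_load(f)

processed_args = []
for key, value in config.items():
    if isinstance(value, bool) and value:
        processed_args.append('--' + key)        # e.g. --trust_remote_code
    else:
        processed_args.extend(['--' + key, str(value)])

# processed_args now starts with ['--model', 'config-model', ...]
```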


@@ -1121,3 +1121,15 @@ def pytest_collection_modifyitems(config, items):
     for item in items:
         if "optional" in item.keywords:
             item.add_marker(skip_optional)
+
+
+@pytest.fixture(scope="session")
+def cli_config_file():
+    """Return the path to the CLI config file."""
+    return os.path.join(_TEST_DIR, "config", "test_config.yaml")
+
+
+@pytest.fixture(scope="session")
+def cli_config_file_with_model():
+    """Return the path to the CLI config file with model."""
+    return os.path.join(_TEST_DIR, "config", "test_config_with_model.yaml")
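A hedged usage sketch (the test itself is hypothetical): any test in the session can consume these fixtures by naming them as parameters, and pytest injects the returned paths:

```python
import os


# Hypothetical test; pytest matches the parameter names to the
# session-scoped fixtures defined above and injects the paths.
def test_config_fixture_paths(cli_config_file, cli_config_file_with_model):
    assert os.path.exists(cli_config_file)
    assert cli_config_file_with_model.endswith("test_config_with_model.yaml")
```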


@@ -8,7 +8,7 @@ from unittest.mock import patch

 import pytest
 import torch
-from vllm_test_utils import monitor
+from vllm_test_utils.monitor import monitor

 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
@@ -140,7 +140,8 @@ def parser():
 def parser_with_config():
     parser = FlexibleArgumentParser()
     parser.add_argument('serve')
-    parser.add_argument('model_tag')
+    parser.add_argument('model_tag', nargs='?')
+    parser.add_argument('--model', type=str)
     parser.add_argument('--served-model-name', type=str)
     parser.add_argument('--config', type=str)
     parser.add_argument('--port', type=int)
@@ -196,29 +197,29 @@ def test_missing_required_argument(parser):
         parser.parse_args([])


-def test_cli_override_to_config(parser_with_config):
+def test_cli_override_to_config(parser_with_config, cli_config_file):
     args = parser_with_config.parse_args([
-        'serve', 'mymodel', '--config', './data/test_config.yaml',
+        'serve', 'mymodel', '--config', cli_config_file,
         '--tensor-parallel-size', '3'
     ])
     assert args.tensor_parallel_size == 3
     args = parser_with_config.parse_args([
         'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
-        './data/test_config.yaml'
+        cli_config_file
     ])
     assert args.tensor_parallel_size == 3
     assert args.port == 12312
     args = parser_with_config.parse_args([
         'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
-        './data/test_config.yaml', '--port', '666'
+        cli_config_file, '--port', '666'
     ])
     assert args.tensor_parallel_size == 3
     assert args.port == 666


-def test_config_args(parser_with_config):
+def test_config_args(parser_with_config, cli_config_file):
     args = parser_with_config.parse_args(
-        ['serve', 'mymodel', '--config', './data/test_config.yaml'])
+        ['serve', 'mymodel', '--config', cli_config_file])
     assert args.tensor_parallel_size == 2
     assert args.trust_remote_code
     assert not args.multi_step_stream_outputs
@@ -240,10 +241,9 @@ def test_config_file(parser_with_config):
     ])


-def test_no_model_tag(parser_with_config):
+def test_no_model_tag(parser_with_config, cli_config_file):
     with pytest.raises(ValueError):
-        parser_with_config.parse_args(
-            ['serve', '--config', './data/test_config.yaml'])
+        parser_with_config.parse_args(['serve', '--config', cli_config_file])


 # yapf: enable
@@ -476,3 +476,34 @@ def test_swap_dict_values(obj, key1, key2):
         assert obj[key1] == original_obj[key2]
     else:
         assert key1 not in obj
+
+
+def test_model_specification(parser_with_config,
+                             cli_config_file,
+                             cli_config_file_with_model):
+    # Test model in CLI takes precedence over config
+    args = parser_with_config.parse_args([
+        'serve', 'cli-model', '--config', cli_config_file_with_model
+    ])
+    assert args.model_tag == 'cli-model'
+    assert args.served_model_name == 'mymodel'
+
+    # Test model from config file works
+    args = parser_with_config.parse_args([
+        'serve', '--config', cli_config_file_with_model
+    ])
+    assert args.model == 'config-model'
+    assert args.served_model_name == 'mymodel'
+
+    # Test no model specified anywhere raises error
+    with pytest.raises(ValueError, match="No model specified!"):
+        parser_with_config.parse_args(['serve', '--config', cli_config_file])
+
+    # Test other config values are preserved
+    args = parser_with_config.parse_args([
+        'serve', 'cli-model', '--config', cli_config_file_with_model
+    ])
+    assert args.tensor_parallel_size == 2
+    assert args.trust_remote_code is True
+    assert args.multi_step_stream_outputs is False
+    assert args.port == 12312


@@ -21,14 +21,16 @@ class ServeSubcommand(CLISubcommand):
     @staticmethod
     def cmd(args: argparse.Namespace) -> None:
-        # The default value of `--model`
-        if args.model != EngineArgs.model:
-            raise ValueError(
-                "With `vllm serve`, you should provide the model as a "
-                "positional argument instead of via the `--model` option.")
+        # If model is specified in CLI (as positional arg), it takes precedence
+        if hasattr(args, 'model_tag') and args.model_tag is not None:
+            args.model = args.model_tag
+        # Otherwise use model from config (already in args.model)

-        # EngineArgs expects the model name to be passed as --model.
-        args.model = args.model_tag
+        # Check if we have a model specified somewhere
+        if args.model == EngineArgs.model:  # Still has default value
+            raise ValueError(
+                "With `vllm serve`, you should provide the model either as a "
+                "positional argument or in config file.")

         uvloop.run(run_server(args))
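The resolution order implemented above can be distilled into a standalone sketch (hypothetical helper, not part of this PR):

```python
from typing import Optional


# Hypothetical distillation of ServeSubcommand.cmd's model resolution:
# a positional CLI tag wins; otherwise keep the --model value the config
# file injected; if the model still equals the EngineArgs default,
# nothing specified a model anywhere, so fail.
def resolve_model(model_tag: Optional[str], model: str, default: str) -> str:
    if model_tag is not None:
        return model_tag    # positional argument takes precedence
    if model != default:
        return model        # value injected from the config file
    raise ValueError("No model specified via CLI or config file.")
```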
@@ -41,10 +43,12 @@ class ServeSubcommand(CLISubcommand):
         serve_parser = subparsers.add_parser(
             "serve",
             help="Start the vLLM OpenAI Compatible API server",
-            usage="vllm serve <model_tag> [options]")
+            usage="vllm serve [model_tag] [options]")
         serve_parser.add_argument("model_tag",
                                   type=str,
-                                  help="The model tag to serve")
+                                  nargs='?',
+                                  help="The model tag to serve "
+                                  "(optional if specified in config)")
         serve_parser.add_argument(
             "--config",
             type=str,


@@ -1264,19 +1264,29 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
         config_args = self._load_config_file(file_path)

         # 0th index is for {serve,chat,complete}
-        # followed by model_tag (only for serve)
+        # optionally followed by model_tag (only for serve)
         # followed by config args
         # followed by rest of cli args.
         # maintaining this order will enforce the precedence
         # of cli > config > defaults
         if args[0] == "serve":
-            if index == 1:
+            model_in_cli = len(args) > 1 and not args[1].startswith('-')
+            model_in_config = any(arg == '--model' for arg in config_args)
+
+            if not model_in_cli and not model_in_config:
                 raise ValueError(
-                    "No model_tag specified! Please check your command-line"
-                    " arguments.")
-            args = [args[0]] + [
-                args[1]
-            ] + config_args + args[2:index] + args[index + 2:]
+                    "No model specified! Please specify model either in "
+                    "command-line arguments or in config file.")
+
+            if model_in_cli:
+                # Model specified as positional arg, keep CLI version
+                args = [args[0]] + [
+                    args[1]
+                ] + config_args + args[2:index] + args[index + 2:]
+            else:
+                # No model in CLI, use config if available
+                args = [args[0]
+                        ] + config_args + args[1:index] + args[index + 2:]
         else:
             args = [args[0]] + config_args + args[1:index] + args[index + 2:]
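A worked example of the reordering above (editor's sketch; the config contents are assumed to match `test_config_with_model.yaml`):

```python
# Suppose the incoming argv is:
#   args  = ['serve', '--config', 'cfg.yaml', '--port', '666']
#   index = 1                    # position of '--config'
# and _load_config_file returned:
#   config_args = ['--model', 'config-model', '--port', '12312']
#
# There is no positional tag, so model_in_cli is False and
# model_in_config is True; the else branch rebuilds argv as:
#   [args[0]] + config_args + args[1:index] + args[index + 2:]
#   == ['serve', '--model', 'config-model', '--port', '12312',
#       '--port', '666']
# argparse keeps the last occurrence of a repeated flag, so --port
# resolves to 666 (CLI) while the model comes from the config file.
```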
@@ -1294,9 +1304,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
             '--port': '12323',
             '--tensor-parallel-size': '4'
         ]
         """
         extension: str = file_path.split('.')[-1]
         if extension not in ('yaml', 'yml'):
             raise ValueError(
@@ -1321,7 +1329,15 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
             if isinstance(action, StoreBoolean)
         ]

+        # Skip model from config if it's provided as positional argument
+        skip_model = (hasattr(self, '_parsed_args') and self._parsed_args
+                      and len(self._parsed_args) > 1
+                      and self._parsed_args[0] == 'serve'
+                      and not self._parsed_args[1].startswith('-'))
+
         for key, value in config.items():
+            if skip_model and key == 'model':
+                continue
             if isinstance(value, bool) and key not in store_boolean_arguments:
                 if value:
                     processed_args.append('--' + key)
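And a sketch of the skip in action (editor's illustration; values assumed):

```python
# Given config = {'model': 'config-model', 'port': 12312} and an argv
# that already carries a positional tag ('serve mymodel --config ...'),
# skip_model is True: the 'model' key is dropped and only
# ['--port', '12312'] is emitted, so the CLI tag never has to compete
# with a --model flag injected from the config file.
```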