[Feature] specify model in config.yaml (#15798)

Signed-off-by: weizeng <weizeng@roblox.com>
Wei Zeng 2025-04-01 01:20:06 -07:00 committed by GitHub
parent 8af5a5c4e5
commit 30d6a015e0
7 changed files with 109 additions and 32 deletions


@@ -188,6 +188,7 @@ For example:
 ```yaml
 # config.yaml
+model: meta-llama/Llama-3.1-8B-Instruct
 host: "127.0.0.1"
 port: 6379
 uvicorn-log-level: "info"
@@ -196,12 +197,13 @@ uvicorn-log-level: "info"
 To use the above config file:
 ```bash
-vllm serve SOME_MODEL --config config.yaml
+vllm serve --config config.yaml
 ```
 :::{note}
 In case an argument is supplied simultaneously using command line and the config file, the value from the command line will take precedence.
 The order of priorities is `command line > config file values > defaults`.
+e.g. `vllm serve SOME_MODEL --config config.yaml`, SOME_MODEL takes precedence over `model` in config file.
 :::
 ## API Reference


@@ -0,0 +1,7 @@
+# Same as test_config.yaml but with model specified
+model: config-model
+port: 12312
+served_model_name: mymodel
+tensor_parallel_size: 2
+trust_remote_code: true
+multi_step_stream_outputs: false


@@ -1117,3 +1117,15 @@ def pytest_collection_modifyitems(config, items):
     for item in items:
         if "optional" in item.keywords:
             item.add_marker(skip_optional)
+
+
+@pytest.fixture(scope="session")
+def cli_config_file():
+    """Return the path to the CLI config file."""
+    return os.path.join(_TEST_DIR, "config", "test_config.yaml")
+
+
+@pytest.fixture(scope="session")
+def cli_config_file_with_model():
+    """Return the path to the CLI config file with model."""
+    return os.path.join(_TEST_DIR, "config", "test_config_with_model.yaml")
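These are plain session-scoped path fixtures, so a test consumes them simply by declaring them as parameters. A rough usage sketch (hypothetical test name, not part of the diff; it mirrors the new `test_model_specification` test further down):

```python
# Hypothetical usage sketch: pytest resolves the session-scoped fixtures above
# by name and injects the config file paths into the test.
def test_model_from_config(parser_with_config, cli_config_file_with_model):
    args = parser_with_config.parse_args(
        ['serve', '--config', cli_config_file_with_model])
    # Values below come from test_config_with_model.yaml shown above.
    assert args.model == 'config-model'
    assert args.served_model_name == 'mymodel'
```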


@@ -10,7 +10,7 @@ from unittest.mock import patch
 import pytest
 import torch
-from vllm_test_utils import monitor
+from vllm_test_utils.monitor import monitor
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
@@ -143,7 +143,8 @@ def parser():
 def parser_with_config():
     parser = FlexibleArgumentParser()
     parser.add_argument('serve')
-    parser.add_argument('model_tag')
+    parser.add_argument('model_tag', nargs='?')
+    parser.add_argument('--model', type=str)
     parser.add_argument('--served-model-name', type=str)
     parser.add_argument('--config', type=str)
     parser.add_argument('--port', type=int)
@@ -199,29 +200,29 @@ def test_missing_required_argument(parser):
         parser.parse_args([])


-def test_cli_override_to_config(parser_with_config):
+def test_cli_override_to_config(parser_with_config, cli_config_file):
     args = parser_with_config.parse_args([
-        'serve', 'mymodel', '--config', './data/test_config.yaml',
+        'serve', 'mymodel', '--config', cli_config_file,
         '--tensor-parallel-size', '3'
     ])
     assert args.tensor_parallel_size == 3
     args = parser_with_config.parse_args([
         'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
-        './data/test_config.yaml'
+        cli_config_file
     ])
     assert args.tensor_parallel_size == 3
     assert args.port == 12312
     args = parser_with_config.parse_args([
         'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
-        './data/test_config.yaml', '--port', '666'
+        cli_config_file, '--port', '666'
     ])
     assert args.tensor_parallel_size == 3
     assert args.port == 666


-def test_config_args(parser_with_config):
+def test_config_args(parser_with_config, cli_config_file):
     args = parser_with_config.parse_args(
-        ['serve', 'mymodel', '--config', './data/test_config.yaml'])
+        ['serve', 'mymodel', '--config', cli_config_file])
     assert args.tensor_parallel_size == 2
     assert args.trust_remote_code
     assert not args.multi_step_stream_outputs
@@ -243,10 +244,9 @@ def test_config_file(parser_with_config):
     ])


-def test_no_model_tag(parser_with_config):
+def test_no_model_tag(parser_with_config, cli_config_file):
     with pytest.raises(ValueError):
-        parser_with_config.parse_args(
-            ['serve', '--config', './data/test_config.yaml'])
+        parser_with_config.parse_args(['serve', '--config', cli_config_file])


 # yapf: enable
@@ -480,6 +480,48 @@ def test_swap_dict_values(obj, key1, key2):
     else:
         assert key1 not in obj
+
+
+def test_model_specification(parser_with_config,
+                             cli_config_file,
+                             cli_config_file_with_model):
+    # Test model in CLI takes precedence over config
+    args = parser_with_config.parse_args([
+        'serve', 'cli-model', '--config', cli_config_file_with_model
+    ])
+    assert args.model_tag == 'cli-model'
+    assert args.served_model_name == 'mymodel'
+
+    # Test model from config file works
+    args = parser_with_config.parse_args([
+        'serve', '--config', cli_config_file_with_model,
+    ])
+    assert args.model == 'config-model'
+    assert args.served_model_name == 'mymodel'
+
+    # Test no model specified anywhere raises error
+    with pytest.raises(ValueError, match="No model specified!"):
+        parser_with_config.parse_args(['serve', '--config', cli_config_file])
+
+    # Test using --model option raises error
+    with pytest.raises(
+            ValueError,
+            match=(
+                "With `vllm serve`, you should provide the model as a positional "
+                "argument or in a config file instead of via the `--model` option."
+            ),
+    ):
+        parser_with_config.parse_args(['serve', '--model', 'my-model'])
+
+    # Test other config values are preserved
+    args = parser_with_config.parse_args([
+        'serve', 'cli-model', '--config', cli_config_file_with_model,
+    ])
+    assert args.tensor_parallel_size == 2
+    assert args.trust_remote_code is True
+    assert args.multi_step_stream_outputs is False
+    assert args.port == 12312
+
+
 @pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
                                    (None, bool, [1, 2, 3])])
 @pytest.mark.parametrize("output", [0, 1, 2])


@@ -4,7 +4,6 @@ import argparse
 import uvloop

-from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.cli.types import CLISubcommand
 from vllm.entrypoints.openai.api_server import run_server
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
@@ -21,14 +20,9 @@ class ServeSubcommand(CLISubcommand):
     @staticmethod
     def cmd(args: argparse.Namespace) -> None:
-        # The default value of `--model`
-        if args.model != EngineArgs.model:
-            raise ValueError(
-                "With `vllm serve`, you should provide the model as a "
-                "positional argument instead of via the `--model` option.")
-
-        # EngineArgs expects the model name to be passed as --model.
-        args.model = args.model_tag
+        # If model is specified in CLI (as positional arg), it takes precedence
+        if hasattr(args, 'model_tag') and args.model_tag is not None:
+            args.model = args.model_tag

         uvloop.run(run_server(args))
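In plain terms, the new `cmd` logic keeps the CLI's positional model when one was given and otherwise leaves whatever `--model` value the config file injected during parsing. A minimal standalone sketch of that resolution (illustrative only, with hypothetical namespace values; not part of the diff):

```python
import argparse


def resolve_model(args: argparse.Namespace) -> str:
    # Mirrors ServeSubcommand.cmd above: a positional model_tag from the CLI
    # overrides a model value that was pulled in from config.yaml.
    if getattr(args, 'model_tag', None) is not None:
        args.model = args.model_tag
    return args.model


# Model only present via the config file: it is kept.
assert resolve_model(
    argparse.Namespace(model='config-model', model_tag=None)) == 'config-model'

# Model also given positionally on the CLI: the CLI value wins.
assert resolve_model(
    argparse.Namespace(model='config-model',
                       model_tag='cli-model')) == 'cli-model'
```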
@@ -41,10 +35,12 @@ class ServeSubcommand(CLISubcommand):
         serve_parser = subparsers.add_parser(
             "serve",
             help="Start the vLLM OpenAI Compatible API server",
-            usage="vllm serve <model_tag> [options]")
+            usage="vllm serve [model_tag] [options]")
         serve_parser.add_argument("model_tag",
                                   type=str,
-                                  help="The model tag to serve")
+                                  nargs='?',
+                                  help="The model tag to serve "
+                                  "(optional if specified in config)")
         serve_parser.add_argument(
             "--config",
             type=str,


@@ -1241,6 +1241,16 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
         if args is None:
             args = sys.argv[1:]

+        # Check for --model in command line arguments first
+        if args and args[0] == "serve":
+            model_in_cli_args = any(arg == '--model' for arg in args)
+
+            if model_in_cli_args:
+                raise ValueError(
+                    "With `vllm serve`, you should provide the model as a "
+                    "positional argument or in a config file instead of via "
+                    "the `--model` option.")
+
         if '--config' in args:
             args = self._pull_args_from_config(args)
@@ -1324,19 +1334,29 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
         config_args = self._load_config_file(file_path)

         # 0th index is for {serve,chat,complete}
-        # followed by model_tag (only for serve)
+        # optionally followed by model_tag (only for serve)
         # followed by config args
         # followed by rest of cli args.
         # maintaining this order will enforce the precedence
         # of cli > config > defaults
         if args[0] == "serve":
-            if index == 1:
+            model_in_cli = len(args) > 1 and not args[1].startswith('-')
+            model_in_config = any(arg == '--model' for arg in config_args)
+
+            if not model_in_cli and not model_in_config:
                 raise ValueError(
-                    "No model_tag specified! Please check your command-line"
-                    " arguments.")
-            args = [args[0]] + [
-                args[1]
-            ] + config_args + args[2:index] + args[index + 2:]
+                    "No model specified! Please specify model either "
+                    "as a positional argument or in a config file.")
+
+            if model_in_cli:
+                # Model specified as positional arg, keep CLI version
+                args = [args[0]] + [
+                    args[1]
+                ] + config_args + args[2:index] + args[index + 2:]
+            else:
+                # No model in CLI, use config if available
+                args = [args[0]
+                        ] + config_args + args[1:index] + args[index + 2:]
         else:
             args = [args[0]] + config_args + args[1:index] + args[index + 2:]
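To make the `cli > config > defaults` precedence concrete, here is an illustrative sketch (not part of the diff) of how the splices above rearrange `sys.argv`-style input. It assumes `_load_config_file` expanded a config containing `model: config-model` and `port: 12312` into `--flag value` pairs:

```python
# Illustrative only: reproduces the list splices above outside the parser.
config_args = ['--model', 'config-model', '--port', '12312']  # assumed expansion

# Case 1: no positional model on the CLI, so the config supplies --model.
argv = ['serve', '--config', 'config.yaml', '--port', '666']
index = argv.index('--config')
spliced = [argv[0]] + config_args + argv[1:index] + argv[index + 2:]
assert spliced == [
    'serve', '--model', 'config-model', '--port', '12312', '--port', '666'
]
# argparse keeps the last occurrence, so the CLI's --port 666 wins over 12312.

# Case 2: a positional model on the CLI stays ahead of the config args; later,
# ServeSubcommand.cmd copies it over any --model value injected from the config.
argv = ['serve', 'cli-model', '--config', 'config.yaml']
index = argv.index('--config')
spliced = [argv[0]] + [argv[1]] + config_args + argv[2:index] + argv[index + 2:]
assert spliced == [
    'serve', 'cli-model', '--model', 'config-model', '--port', '12312'
]
```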
@@ -1354,9 +1374,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
                 '--port': '12323',
                 '--tensor-parallel-size': '4'
             ]
         """
         extension: str = file_path.split('.')[-1]
         if extension not in ('yaml', 'yml'):
             raise ValueError(