[Feature] specify model in config.yaml (#15798)

Signed-off-by: weizeng <weizeng@roblox.com>

parent 8af5a5c4e5
commit 30d6a015e0
@@ -188,6 +188,7 @@ For example:
 ```yaml
 # config.yaml

+model: meta-llama/Llama-3.1-8B-Instruct
 host: "127.0.0.1"
 port: 6379
 uvicorn-log-level: "info"
@@ -196,12 +197,13 @@ uvicorn-log-level: "info"
 To use the above config file:

 ```bash
-vllm serve SOME_MODEL --config config.yaml
+vllm serve --config config.yaml
 ```

 :::{note}
 In case an argument is supplied simultaneously using command line and the config file, the value from the command line will take precedence.
 The order of priorities is `command line > config file values > defaults`.
+e.g. with `vllm serve SOME_MODEL --config config.yaml`, `SOME_MODEL` takes precedence over the `model` value in the config file.
 :::

 ## API Reference
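Reviewer note: the precedence in the doc note comes from argument ordering rather than any explicit merge step. Config-derived flags are spliced in ahead of the remaining CLI flags, and argparse keeps the last value it sees. A minimal sketch of that mechanism with plain `argparse` (standalone illustration, not vLLM's parser):

```python
import argparse

# Sketch only: plain argparse, standalone from vLLM. The config file is
# flattened to CLI-style tokens and placed *before* the user's own flags,
# so argparse's last-value-wins rule yields cli > config > defaults.
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)  # the "defaults" tier

config_args = ["--port", "6379"]   # tier 2: values read from config.yaml
cli_args = ["--port", "7000"]      # tier 1: what the user typed

args = parser.parse_args(config_args + cli_args)
assert args.port == 7000           # CLI beats config, both beat the default
```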
tests/config/test_config_with_model.yaml (new file)
@@ -0,0 +1,7 @@
+# Same as test_config.yaml but with model specified
+model: config-model
+port: 12312
+served_model_name: mymodel
+tensor_parallel_size: 2
+trust_remote_code: true
+multi_step_stream_outputs: false
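For context on how this new fixture file is consumed: the config loader flattens YAML keys into CLI-style tokens (the `_load_config_file` docstring further down shows dashed flag names). A rough sketch of that expansion, where the underscore-to-dash normalization and boolean handling are assumptions, not the actual vLLM implementation:

```python
import yaml

def config_to_args(path: str) -> list[str]:
    # Rough sketch, not vLLM's _load_config_file itself: each YAML key
    # becomes a --flag token followed by its stringified value.
    with open(path) as f:
        config = yaml.safe_load(f)
    args: list[str] = []
    for key, value in config.items():
        args += [f"--{key.replace('_', '-')}", str(value)]
    return args

# test_config_with_model.yaml would expand to roughly:
# ['--model', 'config-model', '--port', '12312', '--served-model-name',
#  'mymodel', '--tensor-parallel-size', '2', '--trust-remote-code', 'True',
#  '--multi-step-stream-outputs', 'False']
```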
@@ -1117,3 +1117,15 @@ def pytest_collection_modifyitems(config, items):
     for item in items:
         if "optional" in item.keywords:
             item.add_marker(skip_optional)
+
+
+@pytest.fixture(scope="session")
+def cli_config_file():
+    """Return the path to the CLI config file."""
+    return os.path.join(_TEST_DIR, "config", "test_config.yaml")
+
+
+@pytest.fixture(scope="session")
+def cli_config_file_with_model():
+    """Return the path to the CLI config file with model."""
+    return os.path.join(_TEST_DIR, "config", "test_config_with_model.yaml")
@@ -10,7 +10,7 @@ from unittest.mock import patch

 import pytest
 import torch
-from vllm_test_utils import monitor
+from vllm_test_utils.monitor import monitor

 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
@@ -143,7 +143,8 @@ def parser():
 def parser_with_config():
     parser = FlexibleArgumentParser()
     parser.add_argument('serve')
-    parser.add_argument('model_tag')
+    parser.add_argument('model_tag', nargs='?')
+    parser.add_argument('--model', type=str)
     parser.add_argument('--served-model-name', type=str)
     parser.add_argument('--config', type=str)
     parser.add_argument('--port', type=int)
@@ -199,29 +200,29 @@ def test_missing_required_argument(parser):
         parser.parse_args([])


-def test_cli_override_to_config(parser_with_config):
+def test_cli_override_to_config(parser_with_config, cli_config_file):
     args = parser_with_config.parse_args([
-        'serve', 'mymodel', '--config', './data/test_config.yaml',
+        'serve', 'mymodel', '--config', cli_config_file,
         '--tensor-parallel-size', '3'
     ])
     assert args.tensor_parallel_size == 3
     args = parser_with_config.parse_args([
         'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
-        './data/test_config.yaml'
+        cli_config_file
     ])
     assert args.tensor_parallel_size == 3
     assert args.port == 12312
     args = parser_with_config.parse_args([
         'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
-        './data/test_config.yaml', '--port', '666'
+        cli_config_file, '--port', '666'
     ])
     assert args.tensor_parallel_size == 3
     assert args.port == 666


-def test_config_args(parser_with_config):
+def test_config_args(parser_with_config, cli_config_file):
     args = parser_with_config.parse_args(
-        ['serve', 'mymodel', '--config', './data/test_config.yaml'])
+        ['serve', 'mymodel', '--config', cli_config_file])
     assert args.tensor_parallel_size == 2
     assert args.trust_remote_code
     assert not args.multi_step_stream_outputs
@@ -243,10 +244,9 @@ def test_config_file(parser_with_config):
     ])


-def test_no_model_tag(parser_with_config):
+def test_no_model_tag(parser_with_config, cli_config_file):
     with pytest.raises(ValueError):
-        parser_with_config.parse_args(
-            ['serve', '--config', './data/test_config.yaml'])
+        parser_with_config.parse_args(['serve', '--config', cli_config_file])


 # yapf: enable
@@ -480,6 +480,48 @@ def test_swap_dict_values(obj, key1, key2):
     else:
         assert key1 not in obj


+def test_model_specification(parser_with_config,
+                             cli_config_file,
+                             cli_config_file_with_model):
+    # Test model in CLI takes precedence over config
+    args = parser_with_config.parse_args([
+        'serve', 'cli-model', '--config', cli_config_file_with_model
+    ])
+    assert args.model_tag == 'cli-model'
+    assert args.served_model_name == 'mymodel'
+
+    # Test model from config file works
+    args = parser_with_config.parse_args([
+        'serve', '--config', cli_config_file_with_model,
+    ])
+    assert args.model == 'config-model'
+    assert args.served_model_name == 'mymodel'
+
+    # Test no model specified anywhere raises error
+    with pytest.raises(ValueError, match="No model specified!"):
+        parser_with_config.parse_args(['serve', '--config', cli_config_file])
+
+    # Test using --model option raises error
+    with pytest.raises(
+            ValueError,
+            match=(
+                "With `vllm serve`, you should provide the model as a positional "
+                "argument or in a config file instead of via the `--model` option."
+            ),
+    ):
+        parser_with_config.parse_args(['serve', '--model', 'my-model'])
+
+    # Test other config values are preserved
+    args = parser_with_config.parse_args([
+        'serve', 'cli-model', '--config', cli_config_file_with_model,
+    ])
+    assert args.tensor_parallel_size == 2
+    assert args.trust_remote_code is True
+    assert args.multi_step_stream_outputs is False
+    assert args.port == 12312
+
+
 @pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
                                    (None, bool, [1, 2, 3])])
 @pytest.mark.parametrize("output", [0, 1, 2])
@@ -4,7 +4,6 @@ import argparse

 import uvloop

-from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.cli.types import CLISubcommand
 from vllm.entrypoints.openai.api_server import run_server
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
@@ -21,14 +20,9 @@ class ServeSubcommand(CLISubcommand):

     @staticmethod
     def cmd(args: argparse.Namespace) -> None:
-        # The default value of `--model`
-        if args.model != EngineArgs.model:
-            raise ValueError(
-                "With `vllm serve`, you should provide the model as a "
-                "positional argument instead of via the `--model` option.")
-
-        # EngineArgs expects the model name to be passed as --model.
-        args.model = args.model_tag
+        # If model is specified in CLI (as positional arg), it takes precedence
+        if hasattr(args, 'model_tag') and args.model_tag is not None:
+            args.model = args.model_tag

         uvloop.run(run_server(args))

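A condensed restatement of the resolution that `cmd()` now performs (the `resolve_model` helper is hypothetical, for illustration only):

```python
from argparse import Namespace
from typing import Optional

def resolve_model(args: Namespace) -> Optional[str]:
    # Hypothetical helper restating the logic above: a positional
    # model_tag from the CLI wins; otherwise args.model keeps whatever
    # the config file injected. Earlier parsing already raised if
    # neither was given.
    if getattr(args, 'model_tag', None) is not None:
        return args.model_tag
    return getattr(args, 'model', None)

assert resolve_model(
    Namespace(model_tag='cli-model', model='config-model')) == 'cli-model'
assert resolve_model(
    Namespace(model_tag=None, model='config-model')) == 'config-model'
```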
@@ -41,10 +35,12 @@ class ServeSubcommand(CLISubcommand):
         serve_parser = subparsers.add_parser(
             "serve",
             help="Start the vLLM OpenAI Compatible API server",
-            usage="vllm serve <model_tag> [options]")
+            usage="vllm serve [model_tag] [options]")
         serve_parser.add_argument("model_tag",
                                   type=str,
-                                  help="The model tag to serve")
+                                  nargs='?',
+                                  help="The model tag to serve "
+                                  "(optional if specified in config)")
         serve_parser.add_argument(
             "--config",
             type=str,
@@ -1241,6 +1241,16 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
         if args is None:
             args = sys.argv[1:]

+        # Check for --model in command line arguments first
+        if args and args[0] == "serve":
+            model_in_cli_args = any(arg == '--model' for arg in args)
+
+            if model_in_cli_args:
+                raise ValueError(
+                    "With `vllm serve`, you should provide the model as a "
+                    "positional argument or in a config file instead of via "
+                    "the `--model` option.")
+
         if '--config' in args:
             args = self._pull_args_from_config(args)

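The guard scans the raw token list before any parsing, so it fires wherever `--model` appears under the `serve` subcommand. A standalone restatement of its behavior (not the vLLM class itself):

```python
def reject_model_flag(argv: list[str]) -> None:
    # Standalone restatement of the guard above: under the `serve`
    # subcommand an explicit --model token is rejected up front,
    # steering users to the positional or config-file forms.
    if argv and argv[0] == "serve" and any(a == "--model" for a in argv):
        raise ValueError("use a positional model_tag or a config file, "
                         "not --model")

reject_model_flag(["serve", "my-model"])                 # ok: positional
reject_model_flag(["serve", "--config", "config.yaml"])  # ok: config
try:
    reject_model_flag(["serve", "--model", "my-model"])
except ValueError:
    pass  # rejected, as the new test expects
```

One observable consequence of the exact-token match: the `--model=name` spelling contains no token equal to `--model`, so it would not trip this check.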
@@ -1324,19 +1334,29 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
         config_args = self._load_config_file(file_path)

         # 0th index is for {serve,chat,complete}
-        # followed by model_tag (only for serve)
+        # optionally followed by model_tag (only for serve)
         # followed by config args
         # followed by rest of cli args.
         # maintaining this order will enforce the precedence
         # of cli > config > defaults
         if args[0] == "serve":
-            if index == 1:
+            model_in_cli = len(args) > 1 and not args[1].startswith('-')
+            model_in_config = any(arg == '--model' for arg in config_args)
+
+            if not model_in_cli and not model_in_config:
                 raise ValueError(
-                    "No model_tag specified! Please check your command-line"
-                    " arguments.")
-            args = [args[0]] + [
-                args[1]
-            ] + config_args + args[2:index] + args[index + 2:]
+                    "No model specified! Please specify model either "
+                    "as a positional argument or in a config file.")
+
+            if model_in_cli:
+                # Model specified as positional arg, keep CLI version
+                args = [args[0]] + [
+                    args[1]
+                ] + config_args + args[2:index] + args[index + 2:]
+            else:
+                # No model in CLI, use config if available
+                args = [args[0]
+                        ] + config_args + args[1:index] + args[index + 2:]
         else:
             args = [args[0]] + config_args + args[1:index] + args[index + 2:]

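The splice order above is what implements the documented precedence. A toy trace of the `model_in_cli` branch, using hypothetical values but the exact expression from the code:

```python
# Toy trace of the reordering performed above, assuming --config sits at
# position `index` and expands to config_args (values are hypothetical).
args = ['serve', 'mymodel', '--config', 'cfg.yaml', '--port', '666']
index = args.index('--config')                       # 2
config_args = ['--port', '12312', '--tensor-parallel-size', '2']

# model_in_cli is True, so the positional tag stays first, config args
# are spliced in next, and the remaining CLI flags come last:
merged = [args[0]] + [args[1]] + config_args + args[2:index] + args[index + 2:]
assert merged == ['serve', 'mymodel', '--port', '12312',
                  '--tensor-parallel-size', '2', '--port', '666']
# argparse keeps the last occurrence, so the CLI's --port 666 wins.
```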
@@ -1354,9 +1374,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
             '--port': '12323',
             '--tensor-parallel-size': '4'
         ]
-
         """
-
         extension: str = file_path.split('.')[-1]
         if extension not in ('yaml', 'yml'):
             raise ValueError(