[Frontend] Add Early Validation For Chat Template / Tool Call Parser (#9151)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
parent a3691b6b5e
commit 069d3bd8d0
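
This commit converts the OpenAI entrypoint CLI-args tests from unittest to pytest and adds early validation of serve arguments: a new validate_parsed_serve_args (backed by validate_chat_template in chat_utils) runs right after argument parsing in both the API server and the `vllm serve` CLI path, so an invalid chat template or a tool-choice flag without a tool-call parser fails before any model loads.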
@@ -1,37 +1,43 @@
 import json
-import unittest
 
-from vllm.entrypoints.openai.cli_args import make_arg_parser
+import pytest
+
+from vllm.entrypoints.openai.cli_args import (make_arg_parser,
+                                              validate_parsed_serve_args)
 from vllm.entrypoints.openai.serving_engine import LoRAModulePath
 from vllm.utils import FlexibleArgumentParser
 
+from ...utils import VLLM_PATH
+
 LORA_MODULE = {
     "name": "module2",
     "path": "/path/to/module2",
     "base_model_name": "llama"
 }
+CHATML_JINJA_PATH = VLLM_PATH / "examples/template_chatml.jinja"
+assert CHATML_JINJA_PATH.exists()
 
 
-class TestLoraParserAction(unittest.TestCase):
+@pytest.fixture
+def serve_parser():
+    parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
+    return make_arg_parser(parser)
 
-    def setUp(self):
-        # Setting up argparse parser for tests
-        parser = FlexibleArgumentParser(
-            description="vLLM's remote OpenAI server.")
-        self.parser = make_arg_parser(parser)
 
-    def test_valid_key_value_format(self):
+### Tests for Lora module parsing
+def test_valid_key_value_format(serve_parser):
     # Test old format: name=path
-    args = self.parser.parse_args([
+    args = serve_parser.parse_args([
         '--lora-modules',
         'module1=/path/to/module1',
     ])
     expected = [LoRAModulePath(name='module1', path='/path/to/module1')]
-    self.assertEqual(args.lora_modules, expected)
+    assert args.lora_modules == expected
 
 
-    def test_valid_json_format(self):
+def test_valid_json_format(serve_parser):
     # Test valid JSON format input
-    args = self.parser.parse_args([
+    args = serve_parser.parse_args([
         '--lora-modules',
         json.dumps(LORA_MODULE),
     ])
@@ -40,40 +46,44 @@ class TestLoraParserAction(unittest.TestCase):
                        path='/path/to/module2',
                        base_model_name='llama')
     ]
-    self.assertEqual(args.lora_modules, expected)
+    assert args.lora_modules == expected
 
-    def test_invalid_json_format(self):
+
+def test_invalid_json_format(serve_parser):
     # Test invalid JSON format input, missing closing brace
-    with self.assertRaises(SystemExit):
-        self.parser.parse_args([
-            '--lora-modules',
-            '{"name": "module3", "path": "/path/to/module3"'
+    with pytest.raises(SystemExit):
+        serve_parser.parse_args([
+            '--lora-modules', '{"name": "module3", "path": "/path/to/module3"'
         ])
 
-    def test_invalid_type_error(self):
+
+def test_invalid_type_error(serve_parser):
     # Test type error when values are not JSON or key=value
-    with self.assertRaises(SystemExit):
-        self.parser.parse_args([
+    with pytest.raises(SystemExit):
+        serve_parser.parse_args([
             '--lora-modules',
             'invalid_format'  # This is not JSON or key=value format
         ])
 
-    def test_invalid_json_field(self):
+
+def test_invalid_json_field(serve_parser):
     # Test valid JSON format but missing required fields
-    with self.assertRaises(SystemExit):
-        self.parser.parse_args([
+    with pytest.raises(SystemExit):
+        serve_parser.parse_args([
             '--lora-modules',
             '{"name": "module4"}'  # Missing required 'path' field
         ])
 
-    def test_empty_values(self):
-        # Test when no LoRA modules are provided
-        args = self.parser.parse_args(['--lora-modules', ''])
-        self.assertEqual(args.lora_modules, [])
 
-    def test_multiple_valid_inputs(self):
+def test_empty_values(serve_parser):
+    # Test when no LoRA modules are provided
+    args = serve_parser.parse_args(['--lora-modules', ''])
+    assert args.lora_modules == []
+
+
+def test_multiple_valid_inputs(serve_parser):
     # Test multiple valid inputs (both old and JSON format)
-    args = self.parser.parse_args([
+    args = serve_parser.parse_args([
         '--lora-modules',
         'module1=/path/to/module1',
         json.dumps(LORA_MODULE),
@@ -84,8 +94,38 @@ class TestLoraParserAction(unittest.TestCase):
                        path='/path/to/module2',
                        base_model_name='llama')
     ]
-    self.assertEqual(args.lora_modules, expected)
+    assert args.lora_modules == expected
 
 
-if __name__ == '__main__':
-    unittest.main()
+### Tests for serve argument validation that run prior to loading
+def test_enable_auto_choice_passes_without_tool_call_parser(serve_parser):
+    """Ensure validation fails if tool choice is enabled with no call parser"""
+    # If we enable-auto-tool-choice, explode with no tool-call-parser
+    args = serve_parser.parse_args(args=["--enable-auto-tool-choice"])
+    with pytest.raises(TypeError):
+        validate_parsed_serve_args(args)
+
+
+def test_enable_auto_choice_passes_with_tool_call_parser(serve_parser):
+    """Ensure validation passes with tool choice enabled with a call parser"""
+    args = serve_parser.parse_args(args=[
+        "--enable-auto-tool-choice",
+        "--tool-call-parser",
+        "mistral",
+    ])
+    validate_parsed_serve_args(args)
+
+
+def test_chat_template_validation_for_happy_paths(serve_parser):
+    """Ensure validation passes if the chat template exists"""
+    args = serve_parser.parse_args(
+        args=["--chat-template",
+              CHATML_JINJA_PATH.absolute().as_posix()])
+    validate_parsed_serve_args(args)
+
+
+def test_chat_template_validation_for_sad_paths(serve_parser):
+    """Ensure validation fails if the chat template doesn't exist"""
+    args = serve_parser.parse_args(args=["--chat-template", "does/not/exist"])
+    with pytest.raises(ValueError):
+        validate_parsed_serve_args(args)
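
The converted tests drive the parser and validators directly, with no server startup involved. A minimal sketch of the same pattern outside pytest, using only names that appear in the diff above:

    from vllm.entrypoints.openai.cli_args import make_arg_parser
    from vllm.utils import FlexibleArgumentParser

    # Build the serve parser the same way the serve_parser fixture does.
    parser = make_arg_parser(
        FlexibleArgumentParser(description="vLLM's remote OpenAI server."))
    args = parser.parse_args(['--lora-modules', 'module1=/path/to/module1'])
    # The custom action parses name=path pairs into LoRAModulePath objects.
    print(args.lora_modules)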
@@ -303,6 +303,28 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
         self._add_placeholder(placeholder)
 
 
+def validate_chat_template(chat_template: Optional[Union[Path, str]]):
+    """Raises if the provided chat template appears invalid."""
+    if chat_template is None:
+        return
+
+    elif isinstance(chat_template, Path) and not chat_template.exists():
+        raise FileNotFoundError(
+            "the supplied chat template path doesn't exist")
+
+    elif isinstance(chat_template, str):
+        JINJA_CHARS = "{}\n"
+        if not any(c in chat_template
+                   for c in JINJA_CHARS) and not Path(chat_template).exists():
+            raise ValueError(
+                f"The supplied chat template string ({chat_template}) "
+                f"appears path-like, but doesn't exist!")
+
+    else:
+        raise TypeError(
+            f"{type(chat_template)} is not a valid chat template type")
+
+
 def load_chat_template(
         chat_template: Optional[Union[Path, str]]) -> Optional[str]:
     if chat_template is None:
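
For reference, the branches above cover four cases. A hedged sketch of the expected behavior, following the code in this hunk (the file paths here are made up for illustration):

    from pathlib import Path
    from vllm.entrypoints.chat_utils import validate_chat_template

    validate_chat_template(None)              # no template configured: no-op
    validate_chat_template("{{ messages }}")  # contains Jinja chars: accepted

    try:
        # A string with no Jinja chars is treated as path-like; this
        # hypothetical path is missing, so validation raises.
        validate_chat_template("no/such/template.jinja")
    except ValueError as e:
        print(e)

    try:
        # A Path object that doesn't exist raises FileNotFoundError.
        validate_chat_template(Path("/no/such/template.jinja"))
    except FileNotFoundError as e:
        print(e)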
@@ -31,7 +31,8 @@ from vllm.engine.multiprocessing.engine import run_mp_engine
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
-from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.entrypoints.openai.cli_args import (make_arg_parser,
+                                              validate_parsed_serve_args)
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -577,5 +578,6 @@ if __name__ == "__main__":
         description="vLLM OpenAI-Compatible RESTful API server.")
     parser = make_arg_parser(parser)
     args = parser.parse_args()
+    validate_parsed_serve_args(args)
 
     uvloop.run(run_server(args))
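
With this one-line hook, an invalid --chat-template or an --enable-auto-tool-choice with no --tool-call-parser now aborts the API server immediately after argument parsing, rather than surfacing only after the engine has loaded the model.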
@@ -10,6 +10,7 @@ import ssl
 from typing import List, Optional, Sequence, Union
 
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
+from vllm.entrypoints.chat_utils import validate_chat_template
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     PromptAdapterPath)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -231,6 +232,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     return parser
 
 
+def validate_parsed_serve_args(args: argparse.Namespace):
+    """Quick checks for model serve args that raise prior to loading."""
+    if hasattr(args, "subparser") and args.subparser != "serve":
+        return
+
+    # Ensure that the chat template is valid; raises if it likely isn't
+    validate_chat_template(args.chat_template)
+
+    # Enable auto tool needs a tool call parser to be valid
+    if args.enable_auto_tool_choice and not args.tool_call_parser:
+        raise TypeError("Error: --enable-auto-tool-choice requires "
+                        "--tool-call-parser")
+
+
 def create_parser_for_docs() -> FlexibleArgumentParser:
     parser_for_docs = FlexibleArgumentParser(
         prog="-m vllm.entrypoints.openai.api_server")
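
The tests earlier in this diff exercise this function directly; a minimal standalone sketch of the fail-fast path, assuming default values for the remaining flags:

    from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                                  validate_parsed_serve_args)
    from vllm.utils import FlexibleArgumentParser

    args = make_arg_parser(FlexibleArgumentParser()).parse_args(
        ["--enable-auto-tool-choice"])
    # No --tool-call-parser was given, so this raises TypeError before
    # any engine or model loading is attempted.
    validate_parsed_serve_args(args)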
@@ -11,7 +11,8 @@ from openai.types.chat import ChatCompletionMessageParam
 
 from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.openai.api_server import run_server
-from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.entrypoints.openai.cli_args import (make_arg_parser,
+                                              validate_parsed_serve_args)
 from vllm.logger import init_logger
 from vllm.utils import FlexibleArgumentParser
 
@@ -142,7 +143,7 @@ def main():
     env_setup()
 
     parser = FlexibleArgumentParser(description="vLLM CLI")
-    subparsers = parser.add_subparsers(required=True)
+    subparsers = parser.add_subparsers(required=True, dest="subparser")
 
     serve_parser = subparsers.add_parser(
         "serve",
@@ -186,6 +187,9 @@ def main():
     chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat")
 
     args = parser.parse_args()
+    if args.subparser == "serve":
+        validate_parsed_serve_args(args)
+
     # One of the sub commands should be executed.
     if hasattr(args, "dispatch_function"):
         args.dispatch_function(args)
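
The dest="subparser" change above is what makes this guard possible: argparse only records the chosen subcommand on the parsed namespace when a dest is given. A standalone illustration with plain argparse (not vLLM code):

    import argparse

    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(required=True, dest="subparser")
    subparsers.add_parser("serve")
    subparsers.add_parser("chat")

    args = parser.parse_args(["serve"])
    # main() checks exactly this attribute before calling
    # validate_parsed_serve_args.
    assert args.subparser == "serve"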