[Feature] vLLM CLI (#5090)

Co-authored-by: simon-mo <simon.mo@hey.com>
Authored by Ethan Xu on 2024-07-14 15:36:43 -07:00; committed by GitHub
parent 73030b7dae
commit dbfe254eda
7 changed files with 223 additions and 36 deletions

@@ -2,8 +2,8 @@
On the server side, run one of the following commands:
vLLM OpenAI API server
python -m vllm.entrypoints.openai.api_server \
--model <your_model> --swap-space 16 \
vllm serve <your_model> \
--swap-space 16 \
--disable-log-requests
(TGI backend)

@@ -109,7 +109,7 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
```{argparse}
:module: vllm.entrypoints.openai.cli_args
:func: make_arg_parser
:func: create_parser_for_docs
:prog: -m vllm.entrypoints.openai.api_server
```

@@ -488,4 +488,9 @@ setup(
},
cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
package_data=package_data,
entry_points={
"console_scripts": [
"vllm=vllm.scripts:main",
],
},
)
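Note: the `console_scripts` entry point above is what exposes the bare `vllm` command after installation. Conceptually, the wrapper that setuptools generates behaves roughly like the following sketch (the actual generated launcher may differ in detail):

```python
# Rough sketch of the launcher implied by "vllm=vllm.scripts:main".
# The real wrapper written by setuptools/pip may look different.
import sys

from vllm.scripts import main

if __name__ == "__main__":
    sys.exit(main())
```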

@@ -14,7 +14,7 @@ import requests
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import get_open_port, is_hip
from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip
if is_hip():
from amdsmi import (amdsmi_get_gpu_vram_usage,
@@ -57,7 +57,9 @@ class RemoteOpenAIServer:
cli_args = cli_args + ["--port", str(get_open_port())]
parser = make_arg_parser()
parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args(cli_args)
self.host = str(args.host or 'localhost')
self.port = int(args.port)
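Note: the two-step pattern above (construct a FlexibleArgumentParser, then hand it to make_arg_parser) is the new calling convention for building the server's CLI parser. A minimal sketch; the model name and port are example values only:

```python
# Minimal sketch of the updated parser-construction pattern.
# "facebook/opt-125m" and port 8000 are placeholder example arguments.
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args(["--model", "facebook/opt-125m", "--port", "8000"])
print(args.model, args.port)
```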

@@ -8,7 +8,7 @@ from typing import Optional, Set
import fastapi
import uvicorn
from fastapi import Request
from fastapi import APIRouter, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -35,10 +35,14 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__)
engine: AsyncLLMEngine
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
@@ -64,35 +68,23 @@ async def lifespan(app: fastapi.FastAPI):
yield
app = fastapi.FastAPI(lifespan=lifespan)
def parse_args():
parser = make_arg_parser()
return parser.parse_args()
router = APIRouter()
# Add prometheus asgi middleware to route /metrics requests
route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
route.path_regex = re.compile('^/metrics(?P<path>.*)$')
app.routes.append(route)
router.routes.append(route)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
@app.get("/health")
@router.get("/health")
async def health() -> Response:
"""Health check."""
await openai_serving_chat.engine.check_health()
return Response(status_code=200)
@app.post("/tokenize")
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
generator = await openai_serving_completion.create_tokenize(request)
if isinstance(generator, ErrorResponse):
@@ -103,7 +95,7 @@ async def tokenize(request: TokenizeRequest):
return JSONResponse(content=generator.model_dump())
@app.post("/detokenize")
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
generator = await openai_serving_completion.create_detokenize(request)
if isinstance(generator, ErrorResponse):
@@ -114,19 +106,19 @@ async def detokenize(request: DetokenizeRequest):
return JSONResponse(content=generator.model_dump())
@app.get("/v1/models")
@router.get("/v1/models")
async def show_available_models():
models = await openai_serving_completion.show_available_models()
return JSONResponse(content=models.model_dump())
@app.get("/version")
@router.get("/version")
async def show_version():
ver = {"version": VLLM_VERSION}
return JSONResponse(content=ver)
@app.post("/v1/chat/completions")
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await openai_serving_chat.create_chat_completion(
@@ -142,7 +134,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
return JSONResponse(content=generator.model_dump())
@app.post("/v1/completions")
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await openai_serving_completion.create_completion(
request, raw_request)
@@ -156,7 +148,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump())
@app.post("/v1/embeddings")
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
generator = await openai_serving_embedding.create_embedding(
request, raw_request)
@@ -167,8 +159,10 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump())
if __name__ == "__main__":
args = parse_args()
def build_app(args):
app = fastapi.FastAPI(lifespan=lifespan)
app.include_router(router)
app.root_path = args.root_path
app.add_middleware(
CORSMiddleware,
@@ -178,6 +172,12 @@ if __name__ == "__main__":
allow_headers=args.allowed_headers,
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)
if token := envs.VLLM_API_KEY or args.api_key:
@app.middleware("http")
@@ -203,6 +203,12 @@ if __name__ == "__main__":
raise ValueError(f"Invalid middleware {middleware}. "
f"Must be a function or a class.")
return app
def run_server(args, llm_engine=None):
app = build_app(args)
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
@@ -211,10 +217,12 @@ if __name__ == "__main__":
else:
served_model_names = [args.model]
engine_args = AsyncEngineArgs.from_cli_args(args)
global engine, engine_args
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = (llm_engine
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
event_loop: Optional[asyncio.AbstractEventLoop]
try:
@@ -230,6 +238,10 @@ if __name__ == "__main__":
# When using single vLLM without engine_use_ray
model_config = asyncio.run(engine.get_model_config())
global openai_serving_chat
global openai_serving_completion
global openai_serving_embedding
openai_serving_chat = OpenAIServingChat(engine, model_config,
served_model_names,
args.response_role,
@@ -258,3 +270,13 @@ if __name__ == "__main__":
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs)
if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
args = parser.parse_args()
run_server(args)
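Note: because run_server() now accepts an optional pre-built engine, the OpenAI-compatible server can also be started programmatically. A hedged sketch, assuming the usual vllm.engine import paths for AsyncEngineArgs and AsyncLLMEngine and using an example model name:

```python
# Sketch: launch the server from Python via run_server(args, llm_engine=...).
# Assumes vllm.engine.arg_utils / vllm.engine.async_llm_engine import paths;
# the model name is an example placeholder.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser

parser = make_arg_parser(FlexibleArgumentParser())
args = parser.parse_args(["--model", "facebook/opt-125m"])

engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs.from_cli_args(args),
    usage_context=UsageContext.OPENAI_API_SERVER)
run_server(args, llm_engine=engine)
```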

@@ -34,9 +34,7 @@ class PromptAdapterParserAction(argparse.Action):
setattr(namespace, self.dest, adapter_list)
def make_arg_parser():
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--host",
type=nullable_str,
default=None,
@@ -133,3 +131,9 @@ def make_arg_parser():
parser = AsyncEngineArgs.add_cli_args(parser)
return parser
def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
prog="-m vllm.entrypoints.openai.api_server")
return make_arg_parser(parser_for_docs)
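Note: create_parser_for_docs() is the helper the `{argparse}` docs directive now points at. Invoking it by hand is a quick way to check the rendered options:

```python
# Print the CLI help that the docs build renders via create_parser_for_docs().
from vllm.entrypoints.openai.cli_args import create_parser_for_docs

parser = create_parser_for_docs()
parser.print_help()
```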

vllm/scripts.py (new file, 154 lines)

@@ -0,0 +1,154 @@
# The CLI entrypoint to vLLM.
import argparse
import os
import signal
import sys
from typing import Optional
from openai import OpenAI
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser
def register_signal_handlers():
def signal_handler(sig, frame):
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTSTP, signal_handler)
def serve(args: argparse.Namespace) -> None:
# EngineArgs expects the model name to be passed as --model.
args.model = args.model_tag
run_server(args)
def interactive_cli(args: argparse.Namespace) -> None:
register_signal_handlers()
base_url = args.url
api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY")
openai_client = OpenAI(api_key=api_key, base_url=base_url)
if args.model_name:
model_name = args.model_name
else:
available_models = openai_client.models.list()
model_name = available_models.data[0].id
print(f"Using model: {model_name}")
if args.command == "complete":
complete(model_name, openai_client)
elif args.command == "chat":
chat(args.system_prompt, model_name, openai_client)
def complete(model_name: str, client: OpenAI) -> None:
print("Please enter prompt to complete:")
while True:
input_prompt = input("> ")
completion = client.completions.create(model=model_name,
prompt=input_prompt)
output = completion.choices[0].text
print(output)
def chat(system_prompt: Optional[str], model_name: str,
client: OpenAI) -> None:
conversation = []
if system_prompt is not None:
conversation.append({"role": "system", "content": system_prompt})
print("Please enter a message for the chat model:")
while True:
input_message = input("> ")
message = {"role": "user", "content": input_message}
conversation.append(message)
chat_completion = client.chat.completions.create(model=model_name,
messages=conversation)
response_message = chat_completion.choices[0].message
output = response_message.content
conversation.append(response_message)
print(output)
def _add_query_options(
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
"--url",
type=str,
default="http://localhost:8000/v1",
help="url of the running OpenAI-Compatible RESTful API server")
parser.add_argument(
"--model-name",
type=str,
default=None,
help=("The model name used in prompt completion, default to "
"the first model in list models API call."))
parser.add_argument(
"--api-key",
type=str,
default=None,
help=(
"API key for OpenAI services. If provided, this api key "
"will overwrite the api key obtained through environment variables."
))
return parser
def main():
parser = FlexibleArgumentParser(description="vLLM CLI")
subparsers = parser.add_subparsers(required=True)
serve_parser = subparsers.add_parser(
"serve",
help="Start the vLLM OpenAI Compatible API server",
usage="vllm serve <model_tag> [options]")
serve_parser.add_argument("model_tag",
type=str,
help="The model tag to serve")
serve_parser = make_arg_parser(serve_parser)
serve_parser.set_defaults(dispatch_function=serve)
complete_parser = subparsers.add_parser(
"complete",
help=("Generate text completions based on the given prompt "
"via the running API server"),
usage="vllm complete [options]")
_add_query_options(complete_parser)
complete_parser.set_defaults(dispatch_function=interactive_cli,
command="complete")
chat_parser = subparsers.add_parser(
"chat",
help="Generate chat completions via the running API server",
usage="vllm chat [options]")
_add_query_options(chat_parser)
chat_parser.add_argument(
"--system-prompt",
type=str,
default=None,
help=("The system prompt to be added to the chat template, "
"used for models that support system prompts."))
chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat")
args = parser.parse_args()
# One of the sub commands should be executed.
if hasattr(args, "dispatch_function"):
args.dispatch_function(args)
else:
parser.print_help()
if __name__ == "__main__":
main()
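Note: the interactive `complete` and `chat` subcommands are thin wrappers over the OpenAI client. A single-shot sketch of the same request, assuming a server is already running at the default URL and that the model name is only an example:

```python
# Sketch of the request "vllm complete" issues for each prompt.
# Assumes a running server at http://localhost:8000/v1; model name is an example.
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
completion = client.completions.create(model="facebook/opt-125m",
                                       prompt="Hello, my name is")
print(completion.choices[0].text)
```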