From 8db1b9d0a178c8c04f4e14d994a50e3b88e0b1ae Mon Sep 17 00:00:00 2001
From: Keyun Tong
Date: Sat, 22 Feb 2025 05:17:44 -0800
Subject: [PATCH] Support SSL Key Rotation in HTTP Server (#13495)

---
Reviewer notes. Text placed between the three-dash marker and the diffstat is
commentary: it is neither part of the commit message nor part of the diff.

What the patch does:
  * Adds an --enable-ssl-refresh flag (store_true, default False) to the
    argument parsers in vllm/entrypoints/api_server.py and
    vllm/entrypoints/openai/cli_args.py. Both run_server() call sites forward
    it to serve_http() as enable_ssl_refresh.
  * serve_http() gains enable_ssl_refresh (default False), calls
    config.load() right after building uvicorn.Config, and, when the flag is
    true, builds an SSLCertRefresher from config.ssl, config.ssl_keyfile,
    config.ssl_certfile and config.ssl_ca_certs. signal_handler() calls the
    refresher's stop() after cancelling the server task.
    NOTE(review): config.load() is presumably what populates config.ssl
    before the refresher reads it -- confirm against the uvicorn docs.
  * New vllm/entrypoints/ssl.py: SSLCertRefresher.__init__ starts up to two
    tasks with asyncio.create_task(): one watching [key_path, cert_path]
    (only when both are set) and one watching [ca_path] (only when set).
    Each task iterates watchfiles.awatch() and, for every reported change,
    calls load_cert_chain(cert_path, key_path) or
    load_verify_locations(ca_path) on the SSLContext that was passed in.
  * requirements-common.txt gains watchfiles.
  * New test drives SSLCertRefresher with a counting MockSSLContext and
    checks the reload counters after touching the watched files and after
    stop().

Points to keep in mind when reading the hunks:
  * In _watch_files() the try/except encloses the whole per-batch loop, so an
    exception raised by the callback is reported through logger.error and the
    remaining changes of that batch are skipped; the async-for loop itself
    continues with the next batch.
  * stop() cancels the tasks and clears the attributes; it does not await
    them. In the server code touched here it is called only from
    signal_handler().
  * The test synchronises with fixed asyncio.sleep(1) calls and creates its
    files under a hard-coded /tmp with delete=False; nothing removes them.
  * NOTE(review): if TLS is not configured, config.ssl may be None while
    enable_ssl_refresh is true -- confirm that SSLCertRefresher is never
    handed a None context together with a non-empty ca_path.
  * NOTE(review): the removed and added "lark == 1.2.2" lines read
    identically in this copy; presumably a whitespace-only change -- verify
    against the commit.
  * NOTE(review): line breaks and indentation of this copy were re-flowed;
    verify the hunks against the commit before applying.

 requirements-common.txt                      |  3 +-
 tests/entrypoints/test_ssl_cert_refresher.py | 72 +++++++++++++++++++
 vllm/entrypoints/api_server.py               |  6 ++
 vllm/entrypoints/launcher.py                 | 14 +++-
 vllm/entrypoints/openai/api_server.py        |  1 +
 vllm/entrypoints/openai/cli_args.py          |  5 ++
 vllm/entrypoints/ssl.py                      | 74 ++++++++++++++++++++
 7 files changed, 173 insertions(+), 2 deletions(-)
 create mode 100644 tests/entrypoints/test_ssl_cert_refresher.py
 create mode 100644 vllm/entrypoints/ssl.py

diff --git a/requirements-common.txt b/requirements-common.txt
index f72aa40f..c0df136f 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -20,7 +20,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.9, < 0.11
 outlines == 0.1.11
-lark == 1.2.2
+lark == 1.2.2
 xgrammar == 0.1.11; platform_machine == "x86_64"
 typing_extensions >= 4.10
 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
@@ -37,3 +37,4 @@ einops # Required for Qwen2-VL.
 compressed-tensors == 0.9.2 # required for compressed-tensors
 depyf==0.18.0 # required for profiling and debugging with compilation config
 cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
+watchfiles # required for http server to monitor the updates of TLS files
diff --git a/tests/entrypoints/test_ssl_cert_refresher.py b/tests/entrypoints/test_ssl_cert_refresher.py
new file mode 100644
index 00000000..23ce7a67
--- /dev/null
+++ b/tests/entrypoints/test_ssl_cert_refresher.py
@@ -0,0 +1,72 @@
+# SPDX-License-Identifier: Apache-2.0
+import asyncio
+import tempfile
+from pathlib import Path
+from ssl import SSLContext
+
+import pytest
+
+from vllm.entrypoints.ssl import SSLCertRefresher
+
+
+class MockSSLContext(SSLContext):
+
+    def __init__(self):
+        self.load_cert_chain_count = 0
+        self.load_ca_count = 0
+
+    def load_cert_chain(
+        self,
+        certfile,
+        keyfile=None,
+        password=None,
+    ):
+        self.load_cert_chain_count += 1
+
+    def load_verify_locations(
+        self,
+        cafile=None,
+        capath=None,
+        cadata=None,
+    ):
+        self.load_ca_count += 1
+
+
+def create_file() -> str:
+    with tempfile.NamedTemporaryFile(dir='/tmp', delete=False) as f:
+        return f.name
+
+
+def touch_file(path: str) -> None:
+    Path(path).touch()
+
+
+@pytest.mark.asyncio
+async def test_ssl_refresher():
+    ssl_context = MockSSLContext()
+    key_path = create_file()
+    cert_path = create_file()
+    ca_path = create_file()
+    ssl_refresher = SSLCertRefresher(ssl_context, key_path, cert_path, ca_path)
+    await asyncio.sleep(1)
+    assert ssl_context.load_cert_chain_count == 0
+    assert ssl_context.load_ca_count == 0
+
+    touch_file(key_path)
+    await asyncio.sleep(1)
+    assert ssl_context.load_cert_chain_count == 1
+    assert ssl_context.load_ca_count == 0
+
+    touch_file(cert_path)
+    touch_file(ca_path)
+    await asyncio.sleep(1)
+    assert ssl_context.load_cert_chain_count == 2
+    assert ssl_context.load_ca_count == 1
+
+    ssl_refresher.stop()
+
+    touch_file(cert_path)
+    touch_file(ca_path)
+    await asyncio.sleep(1)
+    assert ssl_context.load_cert_chain_count == 2
+    assert ssl_context.load_ca_count == 1
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index 11ffc4f6..28b8c847 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -128,6 +128,7 @@ async def run_server(args: Namespace,
     shutdown_task = await serve_http(
         app,
         sock=None,
+        enable_ssl_refresh=args.enable_ssl_refresh,
         host=args.host,
         port=args.port,
         log_level=args.log_level,
@@ -152,6 +153,11 @@ if __name__ == "__main__":
                         type=str,
                         default=None,
                         help="The CA certificates file")
+    parser.add_argument(
+        "--enable-ssl-refresh",
+        action="store_true",
+        default=False,
+        help="Refresh SSL Context when SSL certificate files change")
     parser.add_argument(
         "--ssl-cert-reqs",
         type=int,
diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py
index 79946a49..b09ee526 100644
--- a/vllm/entrypoints/launcher.py
+++ b/vllm/entrypoints/launcher.py
@@ -12,13 +12,16 @@ from fastapi import FastAPI, Request, Response
 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
 from vllm.engine.multiprocessing import MQEngineDeadError
+from vllm.entrypoints.ssl import SSLCertRefresher
 from vllm.logger import init_logger
 from vllm.utils import find_process_using_port
 
 logger = init_logger(__name__)
 
 
-async def serve_http(app: FastAPI, sock: Optional[socket.socket],
+async def serve_http(app: FastAPI,
+                     sock: Optional[socket.socket],
+                     enable_ssl_refresh: bool = False,
                      **uvicorn_kwargs: Any):
     logger.info("Available routes are:")
     for route in app.routes:
@@ -31,6 +34,7 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
         logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
 
     config = uvicorn.Config(app, **uvicorn_kwargs)
+    config.load()
     server = uvicorn.Server(config)
     _add_shutdown_handlers(app, server)
 
@@ -39,9 +43,17 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
     server_task = loop.create_task(
         server.serve(sockets=[sock] if sock else None))
 
+    ssl_cert_refresher = None if not enable_ssl_refresh else SSLCertRefresher(
+        ssl_context=config.ssl,
+        key_path=config.ssl_keyfile,
+        cert_path=config.ssl_certfile,
+        ca_path=config.ssl_ca_certs)
+
     def signal_handler() -> None:
         # prevents the uvicorn signal handler to exit early
         server_task.cancel()
+        if ssl_cert_refresher:
+            ssl_cert_refresher.stop()
 
     async def dummy_shutdown() -> None:
         pass
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index d037a4e6..73061995 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -960,6 +960,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
         shutdown_task = await serve_http(
             app,
             sock=sock,
+            enable_ssl_refresh=args.enable_ssl_refresh,
             host=args.host,
             port=args.port,
             log_level=args.uvicorn_log_level,
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 3054958f..ba953c21 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -164,6 +164,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                         type=nullable_str,
                         default=None,
                         help="The CA certificates file.")
+    parser.add_argument(
+        "--enable-ssl-refresh",
+        action="store_true",
+        default=False,
+        help="Refresh SSL Context when SSL certificate files change")
     parser.add_argument(
         "--ssl-cert-reqs",
         type=int,
diff --git a/vllm/entrypoints/ssl.py b/vllm/entrypoints/ssl.py
new file mode 100644
index 00000000..dba916b8
--- /dev/null
+++ b/vllm/entrypoints/ssl.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+from ssl import SSLContext
+from typing import Callable, Optional
+
+from watchfiles import Change, awatch
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class SSLCertRefresher:
+    """A class that monitors SSL certificate files and
+    reloads them when they change.
+    """
+
+    def __init__(self,
+                 ssl_context: SSLContext,
+                 key_path: Optional[str] = None,
+                 cert_path: Optional[str] = None,
+                 ca_path: Optional[str] = None) -> None:
+        self.ssl = ssl_context
+        self.key_path = key_path
+        self.cert_path = cert_path
+        self.ca_path = ca_path
+
+        # Setup certification chain watcher
+        def update_ssl_cert_chain(change: Change, file_path: str) -> None:
+            logger.info("Reloading SSL certificate chain")
+            assert self.key_path and self.cert_path
+            self.ssl.load_cert_chain(self.cert_path, self.key_path)
+
+        self.watch_ssl_cert_task = None
+        if self.key_path and self.cert_path:
+            self.watch_ssl_cert_task = asyncio.create_task(
+                self._watch_files([self.key_path, self.cert_path],
+                                  update_ssl_cert_chain))
+
+        # Setup CA files watcher
+        def update_ssl_ca(change: Change, file_path: str) -> None:
+            logger.info("Reloading SSL CA certificates")
+            assert self.ca_path
+            self.ssl.load_verify_locations(self.ca_path)
+
+        self.watch_ssl_ca_task = None
+        if self.ca_path:
+            self.watch_ssl_ca_task = asyncio.create_task(
+                self._watch_files([self.ca_path], update_ssl_ca))
+
+    async def _watch_files(self, paths, fun: Callable[[Change, str],
+                                                      None]) -> None:
+        """Watch multiple file paths asynchronously."""
+        logger.info("SSLCertRefresher monitors files: %s", paths)
+        async for changes in awatch(*paths):
+            try:
+                for change, file_path in changes:
+                    logger.info("File change detected: %s - %s", change.name,
+                                file_path)
+                    fun(change, file_path)
+            except Exception as e:
+                logger.error(
+                    "SSLCertRefresher failed taking action on file change. "
+                    "Error: %s", e)
+
+    def stop(self) -> None:
+        """Stop watching files."""
+        if self.watch_ssl_cert_task:
+            self.watch_ssl_cert_task.cancel()
+            self.watch_ssl_cert_task = None
+        if self.watch_ssl_ca_task:
+            self.watch_ssl_ca_task.cancel()
+            self.watch_ssl_ca_task = None