vllm/tests/utils.py
Cyrus Leung 5ae5ed1e60
[Core] Consolidate prompt arguments to LLM engines (#4328)
Co-authored-by: Roger Wang <ywang@roblox.com>
2024-05-28 13:29:31 -07:00

104 lines
2.8 KiB
Python

import os
import subprocess
import sys
import time
import warnings
from contextlib import contextmanager
import ray
import requests
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.utils import get_open_port
# Path to the root of the repository, computed as the parent of the directory
# that contains this file. It is handed to ray.init() as the runtime_env
# "working_dir" (see multi_process_tensor_parallel) so that utilities can be
# imported by ray workers.
VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
@ray.remote(num_gpus=1)
class ServerRunner:
    """Launch the vLLM OpenAI-compatible API server and wait until it is up.

    The class is declared as a Ray remote class requesting one GPU. Its
    constructor starts ``python3 -m vllm.entrypoints.openai.api_server`` as a
    subprocess and blocks until the server's ``/health`` endpoint on
    ``localhost:8000`` answers with HTTP 200.
    """

    # Maximum time, in seconds, to wait for the server to become healthy
    # (600 s, i.e. 10 minutes).
    MAX_SERVER_START_WAIT_S = 600

    def __init__(self, args):
        """Start the API server subprocess and wait for it to become healthy.

        Args:
            args: List of extra command-line arguments appended to the
                server invocation.

        Raises:
            RuntimeError: If the server process exits before becoming
                healthy, or does not become healthy in time.
        """
        env = os.environ.copy()
        # Unbuffered output so server logs show up immediately.
        env["PYTHONUNBUFFERED"] = "1"
        self.proc = subprocess.Popen(
            ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        self._wait_for_server()

    def ready(self):
        """Return True; callable only once construction has finished."""
        return True

    def _wait_for_server(self):
        """Poll the health endpoint until it returns 200 or a limit is hit.

        Raises:
            RuntimeError: If the subprocess has exited, or if
                ``MAX_SERVER_START_WAIT_S`` seconds elapse without a
                successful health check.
        """
        start = time.time()
        while True:
            last_err = None
            try:
                # A per-request timeout keeps a stalled connection from
                # blocking past the overall start-up deadline.
                response = requests.get("http://localhost:8000/health",
                                        timeout=5)
                if response.status_code == 200:
                    return
            except Exception as err:
                # The name bound by ``except`` is cleared when the handler
                # ends, so keep a reference for exception chaining below.
                last_err = err

            # The checks below run after *every* unsuccessful attempt, not
            # only after an exception, so a reachable-but-unhealthy server
            # (non-200 response) cannot cause an unbounded busy loop.
            if self.proc.poll() is not None:
                raise RuntimeError(
                    "Server exited unexpectedly.") from last_err

            time.sleep(0.5)
            if time.time() - start > self.MAX_SERVER_START_WAIT_S:
                raise RuntimeError(
                    "Server failed to start in time.") from last_err

    def __del__(self):
        """Terminate the server subprocess if it was ever started."""
        # ``proc`` is missing when Popen itself raised inside __init__.
        if hasattr(self, "proc"):
            self.proc.terminate()
def init_test_distributed_environment(
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
    local_rank: int = -1,
) -> None:
    """Initialise the distributed and model-parallel state for one test rank.

    Args:
        tp_size: Tensor-parallel size, forwarded to
            ``ensure_model_parallel_initialized``.
        pp_size: Pipeline-parallel size, forwarded to
            ``ensure_model_parallel_initialized``.
        rank: Rank of the calling process.
        distributed_init_port: Port interpolated into the
            ``tcp://localhost:<port>`` init method.
        local_rank: Local rank forwarded to
            ``init_distributed_environment``; defaults to -1.
    """
    # The world spans every (pipeline, tensor) parallel combination.
    total_ranks = pp_size * tp_size
    init_distributed_environment(
        world_size=total_ranks,
        rank=rank,
        distributed_init_method=f"tcp://localhost:{distributed_init_port}",
        local_rank=local_rank,
    )
    ensure_model_parallel_initialized(tp_size, pp_size)
def multi_process_tensor_parallel(
    tp_size: int,
    pp_size: int,
    test_target,
) -> None:
    """Run ``test_target`` once per rank as Ray remote tasks and wait for all.

    Args:
        tp_size: Tensor-parallel size passed to every task.
        pp_size: Pipeline-parallel size passed to every task.
        test_target: Object exposing ``.remote(tp_size, pp_size, rank, port)``
            (presumably a ``ray.remote`` function - confirm at call sites).

    Any exception raised while launching or awaiting the tasks propagates to
    the caller after Ray has been shut down.
    """
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    ray.init(runtime_env={"working_dir": VLLM_PATH})
    try:
        distributed_init_port = get_open_port()
        refs = [
            test_target.remote(tp_size, pp_size, rank, distributed_init_port)
            for rank in range(tp_size * pp_size)
        ]
        ray.get(refs)
    finally:
        # Always tear Ray down, even when a worker fails; otherwise the
        # still-initialised runtime leaks into whatever runs next in this
        # process.
        ray.shutdown()
@contextmanager
def error_on_warning():
    """
    Context manager that turns every warning emitted inside its scope into a
    raised exception, so that a test using it fails on any warning. The
    previous warning filters are restored when the scope is left.
    """
    scoped_filters = warnings.catch_warnings()
    with scoped_filters:
        # The "error" action makes each warning raise instead of being shown.
        warnings.simplefilter(action="error")
        yield