import os
import subprocess
import sys
import time
from multiprocessing import Pool
from pathlib import Path

import pytest
import requests
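

# Helper that posts a completion request to the test server started by the
# `api_server` fixture; `temperature=0` makes decoding greedy and
# `ignore_eos=True` fixes the output length at `max_tokens`.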
def _query_server(prompt: str, max_tokens: int = 5) -> dict:
    response = requests.post("http://localhost:8000/generate",
                             json={
                                 "prompt": prompt,
                                 "max_tokens": max_tokens,
                                 "temperature": 0,
                                 "ignore_eos": True
                             })
    response.raise_for_status()
    return response.json()
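

# A 500-token generation keeps requests in flight long enough for the
# cancellation test below to terminate them mid-stream.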
def _query_server_long(prompt: str) -> dict:
    return _query_server(prompt, max_tokens=500)
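

# The fixture below starts `api_server_async_engine.py` as a subprocess,
# roughly equivalent to:
#   python -u api_server_async_engine.py --model facebook/opt-125m \
#       --host 127.0.0.1 --tokenizer-pool-size <N> \
#       [--engine-use-ray] [--worker-use-ray]
# and terminates it once the test using it has finished.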
@pytest.fixture
def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
               worker_use_ray: bool):
    script_path = Path(__file__).parent.joinpath(
        "api_server_async_engine.py").absolute()
    commands = [
        sys.executable, "-u",
        str(script_path), "--model", "facebook/opt-125m", "--host",
        "127.0.0.1", "--tokenizer-pool-size",
        str(tokenizer_pool_size)
    ]

    # Copy the environment variables and set `VLLM_ALLOW_ENGINE_USE_RAY=1`
    # so that `--engine-use-ray` does not raise an exception due to its
    # deprecation.
    env_vars = os.environ.copy()
    env_vars["VLLM_ALLOW_ENGINE_USE_RAY"] = "1"

    if engine_use_ray:
        commands.append("--engine-use-ray")
    if worker_use_ray:
        commands.append("--worker-use-ray")
    uvicorn_process = subprocess.Popen(commands, env=env_vars)
    yield
    uvicorn_process.terminate()
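

# The parametrize matrix is shared with the `api_server` fixture above:
# pytest resolves `tokenizer_pool_size`, `worker_use_ray`, and
# `engine_use_ray` by name when the fixture lists them as arguments.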
@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
@pytest.mark.parametrize("worker_use_ray", [False, True])
@pytest.mark.parametrize("engine_use_ray", [False, True])
def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool,
                    engine_use_ray: bool):
    """
    Run the API server and test it.

    We run both the server and the requests in separate processes.

    We test that the server can handle incoming requests, including
    multiple requests at the same time, and that it can handle requests
    being cancelled without crashing.
    """
    with Pool(32) as pool:
        # Wait until the server is ready
        prompts = ["warm up"] * 1
        result = None
        while not result:
            try:
                for r in pool.map(_query_server, prompts):
                    result = r
                    break
            except requests.exceptions.ConnectionError:
                time.sleep(1)

        # Actual tests start here
        # Try with 1 prompt
        for result in pool.map(_query_server, prompts):
            assert result

        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests == 0

        # Try with 100 prompts
        prompts = ["test prompt"] * 100
        for result in pool.map(_query_server, prompts):
            assert result
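
    # Fire off 100 long generations and terminate the pool almost
    # immediately: the workers die with their connections still open, and
    # the server should record the orphaned requests as aborted.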
    with Pool(32) as pool:
        # Cancel requests
        prompts = ["canceled requests"] * 100
        pool.map_async(_query_server_long, prompts)
        time.sleep(0.01)
        pool.terminate()
        pool.join()

        # Check cancellation stats; give the server some time to update them.
        time.sleep(1)

        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests > 0

    # Check that the server still runs after the cancellations
    with Pool(32) as pool:
        # Try with 100 prompts
        prompts = ["test prompt after canceled"] * 100
        for result in pool.map(_query_server, prompts):
            assert result