158 lines
5.1 KiB
Python
158 lines
5.1 KiB
Python
# Tests for the vLLM OpenAI-compatible API server launched with --engine-use-ray
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
|
|
import openai # use the official client for correctness check
|
|
import pytest
|
|
# using Ray for overall ease of process management, parallel requests,
|
|
# and debugging.
|
|
import ray
|
|
import requests
|
|
|
|
# Upper bound, in seconds, on how long ServerRunner waits for the API server's
# health endpoint to return 200 (600 s = 10 minutes).
MAX_SERVER_START_WAIT_S = 600

# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"
|
|
|
|
|
|
@ray.remote(num_gpus=1)
class ServerRunner:
    """Ray actor that runs the vLLM OpenAI-compatible API server.

    The server is started as a subprocess
    (``python3 -m vllm.entrypoints.openai.api_server`` plus ``args``), and the
    constructor blocks until ``http://localhost:8000/health`` returns HTTP 200.
    """

    def __init__(self, args):
        """Start the server subprocess and wait until it is healthy.

        Args:
            args: Extra command-line arguments appended to the server command.

        Raises:
            RuntimeError: If the server process exits, or is not healthy within
                ``MAX_SERVER_START_WAIT_S`` seconds.
        """
        env = os.environ.copy()
        # Unbuffered output so the server's logs reach our stdout/stderr promptly.
        env["PYTHONUNBUFFERED"] = "1"
        self.proc = subprocess.Popen(
            ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        self._wait_for_server()

    def ready(self):
        """Return True; calling it remotely lets a caller block until __init__ is done."""
        return True

    def _wait_for_server(self):
        """Poll the health endpoint until it returns 200, the process dies, or time runs out.

        Raises:
            RuntimeError: If the subprocess has exited, or the deadline passed.
        """
        start = time.monotonic()  # monotonic: immune to wall-clock adjustments
        while True:
            last_err = None
            try:
                # Per-request timeout so a stalled connection cannot hang the loop.
                resp = requests.get("http://localhost:8000/health", timeout=5)
                if resp.status_code == 200:
                    return
            except requests.RequestException as err:
                last_err = err

            # Checked on every iteration (not only on exceptions) so a non-200
            # response cannot cause a tight, never-ending loop.
            if self.proc.poll() is not None:
                raise RuntimeError("Server exited unexpectedly.") from last_err
            if time.monotonic() - start > MAX_SERVER_START_WAIT_S:
                raise RuntimeError(
                    "Server failed to start in time.") from last_err
            time.sleep(0.5)

    def __del__(self):
        """Terminate the server subprocess when this object is finalized."""
        # __init__ may have failed before self.proc was assigned.
        if hasattr(self, "proc"):
            self.proc.terminate()
|
|
|
|
|
|
@pytest.fixture(scope="session")
def server():
    """Session-wide fixture: start one API server actor and share it.

    Initializes Ray, launches a ``ServerRunner`` actor for ``MODEL_NAME``,
    blocks until the actor reports ready, and yields it. Ray is shut down at
    the end of the session, and also when server startup fails.
    """
    ray.init()
    try:
        server_runner = ServerRunner.remote([
            "--model",
            MODEL_NAME,
            # use half precision for speed and memory savings in CI environment
            "--dtype",
            "float16",
            "--max-model-len",
            "2048",
            "--enforce-eager",
            "--engine-use-ray",
        ])
        ray.get(server_runner.ready.remote())
        yield server_runner
    finally:
        # Runs on normal teardown and when startup raised, so Ray never leaks.
        ray.shutdown()
|
|
|
|
|
|
@pytest.fixture(scope="session")
def client():
    """Yield an async OpenAI client aimed at the local test server's /v1 API."""
    yield openai.AsyncOpenAI(
        api_key="token-abc123",
        base_url="http://localhost:8000/v1",
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_check_models(server, client: openai.AsyncOpenAI):
    """The first listed model is MODEL_NAME, and every entry's root is MODEL_NAME."""
    listing = await client.models.list()
    entries = listing.data
    first = entries[0]
    assert first.id == MODEL_NAME
    for entry in entries:
        assert entry.root == MODEL_NAME
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_single_completion(server, client: openai.AsyncOpenAI):
    """Request 5-token, temperature-0 completions from a text prompt and a token-ID prompt."""
    # Text prompt.
    result = await client.completions.create(
        model=MODEL_NAME,
        prompt="Hello, my name is",
        max_tokens=5,
        temperature=0.0,
    )
    assert result.id is not None
    assert result.choices is not None and len(result.choices) == 1
    first_choice = result.choices[0]
    text = first_choice.text
    assert text is not None and len(text) >= 5
    assert first_choice.finish_reason == "length"
    expected_usage = openai.types.CompletionUsage(
        completion_tokens=5, prompt_tokens=6, total_tokens=11)
    assert result.usage == expected_usage

    # Prompt supplied as a list of token IDs instead of text.
    result = await client.completions.create(
        model=MODEL_NAME,
        prompt=[0] * 5,
        max_tokens=5,
        temperature=0.0,
    )
    token_text = result.choices[0].text
    assert token_text is not None and len(token_text) >= 5
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_single_chat_session(server, client: openai.AsyncOpenAI):
    """One chat completion with top-5 logprobs, then a follow-up user turn."""
    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ]

    # First turn: ask for logprobs with the 5 most likely alternatives.
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=10,
        logprobs=True,
        top_logprobs=5,
    )
    assert response.id is not None
    assert response.choices is not None and len(response.choices) == 1
    choice = response.choices[0]
    assert choice.message is not None
    assert choice.logprobs is not None
    assert choice.logprobs.top_logprobs is not None
    assert len(choice.logprobs.top_logprobs[0]) == 5
    reply = choice.message
    assert reply.content is not None and len(reply.content) >= 10
    assert reply.role == "assistant"
    messages.append({"role": "assistant", "content": reply.content})

    # Second turn: continue the conversation with another user message.
    messages.append({"role": "user", "content": "express your result in json"})
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=10,
    )
    reply = response.choices[0].message
    # NOTE(review): `len(...) >= 0` is always true for a str, so this
    # effectively only checks that content is not None.
    assert reply.content is not None and len(reply.content) >= 0
|