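"""Tests for the OpenAI-compatible API server: prompt truncation via the
``truncate_prompt_tokens`` extra-body parameter, and ``/health`` endpoint
responsiveness while the server is under load."""
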
import asyncio
import contextlib
import random
import time
from typing import Callable

import openai
import pytest
import pytest_asyncio
import requests

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"


@pytest.fixture(scope="module")
def server():  # noqa: F811
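    """Start one RemoteOpenAIServer for the whole module, configured to be
    cheap to run in CI (bfloat16, eager mode, dummy weight loading)."""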
    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        "--max-num-seqs",
        "128",
        "--load-format",
        "dummy",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
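    """Provide an AsyncOpenAI client connected to the test server."""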
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ids=["completion", "chat"],
    argnames=["create_func_gen", "content_body"],
    argvalues=[
        (lambda x: x.completions.create, {
            "prompt": " ".join(['A'] * 10_000)
        }),
        (lambda x: x.chat.completions.create, {
            "messages": [{
                "role": "user",
                "content": " ".join(['A'] * 10_000)
            }]
        }),
    ],
)
async def test_with_and_without_truncate(
    server: RemoteOpenAIServer,
    client: openai.AsyncOpenAI,
    create_func_gen: Callable,
    content_body: dict,
):
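    """Fire a shuffled mix of requests with and without
    ``truncate_prompt_tokens`` and verify that none of them causes an
    internal server error (HTTP 500)."""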
    create_func = create_func_gen(client)
    body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}

    num_requests = 10
    truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] *
                              (num_requests - num_requests // 2))
    random.shuffle(truncate_prompt_tokens)

    bodies = [{
        **body, "extra_body": {
            'truncate_prompt_tokens': t
        }
    } for t in truncate_prompt_tokens]

    async def get_status_code(**kwargs):
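        # Return the HTTP status code rather than raising on API errors.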
        try:
            await create_func(**kwargs)
            return 200
        except openai.APIStatusError as e:
            return e.status_code

    responses = await asyncio.gather(*[get_status_code(**b) for b in bodies])
    assert 500 not in responses


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ids=["single completion", "multiple completions", "chat"],
    argnames=["create_func_gen", "content_body"],
    argvalues=[
        (lambda x: x.completions.create, {
            "prompt": " ".join(['A'] * 300_000)
        }),
        (lambda x: x.completions.create, {
            "prompt": [" ".join(['A'] * 300_000)] * 2
        }),
        (lambda x: x.chat.completions.create, {
            "messages": [{
                "role": "user",
                "content": " ".join(['A'] * 300_000)
            }]
        }),
    ],
)
async def test_healthcheck_response_time(
    server: RemoteOpenAIServer,
    client: openai.AsyncOpenAI,
    create_func_gen: Callable,
    content_body: dict,
):
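    """Check that ``/health`` remains responsive while the server is busy
    processing many long-prompt requests."""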
    num_requests = 50

    create_func = create_func_gen(client)
    body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}

    def get_response_time(url):
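        # Measure the wall-clock latency of a successful GET to `url`.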
        start_time = time.monotonic()
        res = requests.get(url)
        end_time = time.monotonic()
        assert res.status_code == 200
        return end_time - start_time

    no_load_response_time = get_response_time(server.url_for("health"))
    tasks = [
        asyncio.create_task(create_func(**body)) for _ in range(num_requests)
    ]
    await asyncio.sleep(1)  # give the tasks a chance to start running
    load_response_time = get_response_time(server.url_for("health"))

    with contextlib.suppress(openai.APIStatusError):
        await asyncio.gather(*tasks)

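    # Health checks should remain fast under load, both relative to the
    # idle baseline and in absolute terms.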
    assert load_response_time < 100 * no_load_response_time
    assert load_response_time < 0.1