vllm/tests/entrypoints/openai/test_models.py

70 lines
1.9 KiB
Python
Raw Normal View History

2024-03-25 23:59:47 +09:00
import openai # use the official client for correctness check
2024-01-17 05:33:14 +00:00
import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from ...utils import VLLM_PATH, RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
2024-01-17 05:33:14 +00:00
@pytest.fixture(scope="module")
def zephyr_lora_files():
    """Download the LoRA repository once per module.

    Returns whatever ``snapshot_download`` returns for ``LORA_NAME``; the
    ``server`` fixture interpolates that value into its ``--lora-modules``
    arguments.
    """
    lora_snapshot = snapshot_download(repo_id=LORA_NAME)
    return lora_snapshot
@pytest.fixture(scope="module")
def ray_ctx():
    """Initialise Ray for the module and shut it down at teardown.

    Ray is started with ``VLLM_PATH`` as its working directory. The fixture
    yields no value; dependants request it only for its setup/teardown
    side effects.
    """
    ray_runtime_env = {"working_dir": VLLM_PATH}
    ray.init(runtime_env=ray_runtime_env)
    yield
    # Everything after the yield runs as fixture teardown.
    ray.shutdown()
@pytest.fixture(scope="module")
def server(zephyr_lora_files, ray_ctx):
    """Create the module-scoped ``RemoteOpenAIServer`` under test.

    The server is configured with ``MODEL_NAME`` as the base model and two
    LoRA modules (``zephyr-lora`` and ``zephyr-lora2``) that both point at
    the same downloaded files.

    ``ray_ctx`` is not referenced in the body; requesting it makes pytest
    set up that fixture before this one (presumably because the server
    relies on an initialised Ray runtime -- confirm against
    ``RemoteOpenAIServer``).
    """
    model_args = [
        "--model",
        MODEL_NAME,
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
    ]
    # lora config below
    lora_args = [
        "--enable-lora",
        "--lora-modules",
        f"zephyr-lora={zephyr_lora_files}",
        f"zephyr-lora2={zephyr_lora_files}",
        "--max-lora-rank",
        "64",
        "--max-cpu-loras",
        "2",
        "--max-num-seqs",
        "128",
    ]
    # Concatenation preserves the original flat argument order.
    return RemoteOpenAIServer(model_args + lora_args)
@pytest.fixture(scope="module")
def client(server):
    """Return the async client obtained from the ``server`` fixture."""
    async_client = server.get_async_client()
    return async_client
2024-01-17 05:33:14 +00:00
@pytest.mark.asyncio
async def test_check_models(client: openai.AsyncOpenAI):
    """Verify the model listing returned by the server.

    The first entry must be the base model, the next two must be the LoRA
    modules ``zephyr-lora`` and ``zephyr-lora2`` in that order, and every
    entry must report ``MODEL_NAME`` as its ``root``.
    """
    listing = await client.models.list()
    entries = listing.data

    # The base model is expected first, followed by the LoRA entries.
    base_entry = entries[0]
    adapter_entries = entries[1:]

    assert base_entry.id == MODEL_NAME
    assert all(entry.root == MODEL_NAME for entry in entries)
    assert adapter_entries[0].id == "zephyr-lora"
    assert adapter_entries[1].id == "zephyr-lora2"