# SPDX-License-Identifier: Apache-2.0

import os
import subprocess
import sys
import time
from subprocess import Popen

import pytest
import requests
import torch


# Fixture to set up environment variables and teardown servers after tests
@pytest.fixture(scope="module", autouse=True)
def setup_servers():
    if torch.cuda.device_count() < 4:
        pytest.skip("Skipping test: fewer than 4 GPUs available")

    # Set up environment variables
    VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'",
                                           shell=True).decode().strip()
    os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP

    # Start prefill instance
    prefill_cmd = [
        sys.executable,
        "-m",
        "vllm.entrypoints.openai.api_server",
        "--model",
        "meta-llama/Llama-3.2-1B-Instruct",
        "--port",
        "8100",
        "--gpu-memory-utilization",
        "0.5",
        "--max-model-len",
        "1000",
        "--kv-transfer-config",
        '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\
        '"kv_rank":0,"kv_parallel_size":2}',
    ]
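    # Per --kv-transfer-config, this instance is the KV producer (kv_rank 0);
    # the decode instance below is the KV consumer (kv_rank 1), and
    # kv_parallel_size=2 covers the pair exchanging KV caches via the
    # PyNcclConnector.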
    prefill_env = os.environ.copy()
    prefill_env["CUDA_VISIBLE_DEVICES"] = "0"
    prefill_proc = Popen(prefill_cmd, env=prefill_env)

    # Start decode instance
    decode_cmd = [
        sys.executable,
        "-m",
        "vllm.entrypoints.openai.api_server",
        "--model",
        "meta-llama/Llama-3.2-1B-Instruct",
        "--port",
        "8200",
        "--gpu-memory-utilization",
        "0.5",
        "--max-model-len",
        "1000",
        "--kv-transfer-config",
        '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\
        '"kv_rank":1,"kv_parallel_size":2}',
    ]
    decode_env = os.environ.copy()
    decode_env["CUDA_VISIBLE_DEVICES"] = "1"
    decode_proc = Popen(decode_cmd, env=decode_env)

    # Wait for servers to be ready
    assert wait_for_server(8100), "Prefill server did not start in time"
    assert wait_for_server(8200), "Decode server did not start in time"

    # Yield to the test function and handle teardown after tests
    yield

    # Cleanup: kill the processes
    prefill_proc.terminate()
    decode_proc.terminate()

    # Additional cleanup if needed
    prefill_proc.wait()
    decode_proc.wait()


# Helper function to wait for server
def wait_for_server(port, timeout=240):
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            response = requests.get(f"http://localhost:{port}/v1/completions")
            # GET on a POST-only route returns 405 once the server is up, so
            # either 200 or 405 means the endpoint is reachable.
            if response.status_code in [200, 405]:
                return True
        except requests.ConnectionError:
            time.sleep(1)
    return False


# Test function to send curl requests and validate responses
@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"])
def test_disaggregated_prefilling(prompt):
    # Send to prefill
    response = requests.post("http://localhost:8100/v1/completions",
                             headers={"Content-Type": "application/json"},
                             json={
                                 "model": "meta-llama/Llama-3.2-1B-Instruct",
                                 "prompt": prompt,
                                 "max_tokens": 1,
                                 "temperature": 0
                             })
    assert response.status_code == 200
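    # max_tokens=1 above limits the prefill instance to prompt processing, so
    # its real output is the KV cache handed over to the decode instance (the
    # disaggregated-prefill example proxy issues its prefill request the same
    # way).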

    # Send to decode
    response = requests.post("http://localhost:8200/v1/completions",
                             headers={"Content-Type": "application/json"},
                             json={
                                 "model": "meta-llama/Llama-3.2-1B-Instruct",
                                 "prompt": prompt,
                                 "max_tokens": 10,
                                 "temperature": 0
                             })
    assert response.status_code == 200
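    # A stricter check could also look at the generated text; a sketch,
    # assuming the standard OpenAI completions response layout returned by
    # the api_server:
    #
    #     completion = response.json()["choices"][0]["text"]
    #     assert len(completion.strip()) > 0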