2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
2024-12-01 19:01:00 -06:00
|
|
|
import subprocess
|
|
|
|
import sys
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
def run_python_script(script_name, timeout):
    """Run ``kv_transfer/<script_name>`` as two concurrent processes, RANK 0 and 1.

    Both processes are launched before either is waited on. Each child gets
    the parent's environment with ``RANK`` set to its rank, and writes to the
    current ``sys.stdout`` / ``sys.stderr``.

    Args:
        script_name: Script file name, resolved as ``kv_transfer/<script_name>``
            relative to the current working directory.
        timeout: Seconds to wait for each process in turn (rank 0 first), so
            the worst-case total wait is roughly twice this value.

    Raises:
        pytest.fail.Exception: If a process exits non-zero, a wait times out,
            or launching/waiting raises any other exception. Any child still
            running at that point is terminated (and killed if it lingers).
    """
    import os  # local import: only this helper needs the parent environment

    script_name = f'kv_transfer/{script_name}'
    # Seconds to let a terminated child exit before escalating to kill().
    terminate_grace = 10
    processes = []
    try:
        # Launch every rank before waiting on any of them; the two ranks are
        # presumably expected to talk to each other, so neither can finish alone.
        for rank in range(2):
            processes.append(
                subprocess.Popen(
                    [sys.executable, script_name],
                    # Extend rather than replace the environment: passing only
                    # RANK would strip PATH, PYTHONPATH, CUDA_VISIBLE_DEVICES...
                    env={
                        **os.environ, "RANK": str(rank)
                    },
                    stdout=sys.stdout,  # Pipe stdout to current stdout
                    stderr=sys.stderr,  # Pipe stderr to current stderr
                ))

        # Wait for both processes to complete, each with the same timeout.
        for process in processes:
            process.wait(timeout=timeout)

        # Check the return status of both processes (rank 0 reported first).
        for rank, process in enumerate(processes):
            if process.returncode != 0:
                pytest.fail(f"Test {script_name} failed for RANK={rank}, "
                            f"{process.returncode}")

    except subprocess.TimeoutExpired:
        pytest.fail(f"Test {script_name} timed out")
    except Exception as e:
        pytest.fail(f"Test {script_name} failed with error: {str(e)}")
    finally:
        # Reap anything still running (after a timeout, or when a later rank
        # failed to launch) so no orphaned rank outlives the test.
        for process in processes:
            if process.poll() is None:
                process.terminate()
                try:
                    process.wait(timeout=terminate_grace)
                except subprocess.TimeoutExpired:
                    process.kill()
                    process.wait()
|
|
|
|
|
|
|
|
|
|
|
|
# Each case pairs a kv_transfer script with the timeout (seconds) handed to
# run_python_script.
@pytest.mark.parametrize(
    "script_name,timeout",
    [
        ("test_lookup_buffer.py", 60),  # lookup-buffer script: 60 s
        ("test_send_recv.py", 120),  # send/recv script: 120 s
    ],
)
def test_run_python_script(script_name, timeout):
    """Run ``script_name`` on two ranks, or skip when fewer than 2 GPUs exist.

    The two-process launch itself is delegated to ``run_python_script`` with
    the parametrized ``timeout``.
    """
    gpu_count = torch.cuda.device_count()
    if gpu_count >= 2:
        run_python_script(script_name, timeout)
    else:
        pytest.skip(
            f"Skipping test {script_name} because <2 GPUs are available")
|