
- **Add SPDX license headers to python source files** - **Check for SPDX headers using pre-commit** commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745 Author: Russell Bryant <rbryant@redhat.com> Date: Fri Jan 31 14:18:24 2025 -0500 Add SPDX license headers to python source files This commit adds SPDX license headers to python source files as recommended to the project by the Linux Foundation. These headers provide a concise way that is both human and machine readable for communicating license information for each source file. It helps avoid any ambiguity about the license of the code and can also be easily used by tools to help manage license compliance. The Linux Foundation runs license scans against the codebase to help ensure we are in compliance with the licenses of the code we use, including dependencies. Having these headers in place helps that tool do its job. More information can be found on the SPDX site: - https://spdx.dev/learn/handling-license-info/ Signed-off-by: Russell Bryant <rbryant@redhat.com> commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea Author: Russell Bryant <rbryant@redhat.com> Date: Fri Jan 31 14:36:32 2025 -0500 Check for SPDX headers using pre-commit Signed-off-by: Russell Bryant <rbryant@redhat.com> --------- Signed-off-by: Russell Bryant <rbryant@redhat.com>
179 lines
5.2 KiB
Python
179 lines
5.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import asyncio
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from functools import partial
|
|
from time import sleep
|
|
from typing import Any, List, Tuple
|
|
|
|
import pytest
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
|
|
ResultHandler, WorkerMonitor)
|
|
from vllm.worker.worker_base import WorkerWrapperBase
|
|
|
|
|
|
class DummyWorkerWrapper(WorkerWrapperBase):
    """Dummy version of vllm.worker.worker.Worker"""

    def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
        """Echo ``worker_input`` back, tagged with this worker's rank.

        Sleeps briefly first so that concurrently submitted calls overlap.

        Raises:
            Exception: ``worker_input`` itself when it is an exception
                instance, to simulate a failing remote call.
        """
        sleep(0.05)

        if isinstance(worker_input, Exception):
            # simulate error case
            raise worker_input

        # Return the actual payload (not the ``input`` builtin) so callers
        # can check that the argument round-tripped through the worker.
        return self.rpc_rank, worker_input


def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
    """Start 8 dummy worker processes plus their result handler and monitor.

    Returns:
        The worker wrappers (index ``i`` is constructed with rank ``i``) and
        the already-started ``WorkerMonitor``.
    """
    result_handler = ResultHandler()
    vllm_config = VllmConfig()
    workers = [
        ProcessWorkerWrapper(result_handler, DummyWorkerWrapper, vllm_config,
                             rank) for rank in range(8)
    ]

    worker_monitor = WorkerMonitor(workers, result_handler)
    assert not worker_monitor.is_alive()

    result_handler.start()
    worker_monitor.start()
    assert worker_monitor.is_alive()

    return workers, worker_monitor


def _assert_workers_stopped(workers: List[ProcessWorkerWrapper],
                            worker_monitor: WorkerMonitor) -> None:
    """Assert that the monitor and every worker process have exited."""
    assert not worker_monitor.is_alive()
    assert all(not worker.process.is_alive() for worker in workers)


def test_local_workers() -> None:
    """Test workers with sync task submission"""

    workers, worker_monitor = _start_workers()

    def execute_workers(worker_input: str) -> None:
        # Fan the same input out to every worker, then check that each reply
        # carries that worker's rank and the echoed input.
        worker_outputs = [
            worker.execute_method("worker_method", worker_input)
            for worker in workers
        ]

        for rank, output in enumerate(worker_outputs):
            assert output.get() == (rank, worker_input)

    # Test concurrent submission from different threads. The context manager
    # shuts the executor down (joining its threads) once the checks finish.
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [
            executor.submit(partial(execute_workers, f"thread {thread_num}"))
            for thread_num in range(4)
        ]
        for future in futures:
            # Re-raises any assertion failure from the submitting thread.
            future.result()

    # Test error case: the exception raised inside the worker must surface
    # from result.get() on the caller's side.
    result = workers[0].execute_method("worker_method",
                                       ValueError("fake error"))
    with pytest.raises(ValueError, match="^fake error$"):
        result.get()

    # Test cleanup when a worker fails
    assert worker_monitor.is_alive()
    workers[3].process.kill()

    # Other workers should get shut down here
    worker_monitor.join(20)

    # Ensure everything is stopped
    _assert_workers_stopped(workers, worker_monitor)

    # Further attempts to submit tasks should fail
    with pytest.raises(ChildProcessError):
        workers[0].execute_method("worker_method", "test")


def test_local_workers_clean_shutdown() -> None:
    """Test clean shutdown"""

    workers, worker_monitor = _start_workers()

    assert worker_monitor.is_alive()
    assert all(worker.process.is_alive() for worker in workers)

    # Clean shutdown
    worker_monitor.close()
    worker_monitor.join(20)

    # Ensure everything is stopped
    _assert_workers_stopped(workers, worker_monitor)

    # Further attempts to submit tasks should fail
    with pytest.raises(ChildProcessError):
        workers[0].execute_method("worker_method", "test")


@pytest.mark.asyncio
async def test_local_workers_async() -> None:
    """Test local workers with async task submission"""

    workers, worker_monitor = _start_workers()

    async def execute_workers(worker_input: str) -> None:
        # Await one call per worker concurrently; gather() returns results
        # in argument order, so the result index equals the worker's rank.
        results = await asyncio.gather(*(
            worker.execute_method_async("worker_method", worker_input)
            for worker in workers))
        for rank, result in enumerate(results):
            assert result == (rank, worker_input)

    # Run four concurrent fan-outs to exercise overlapping submissions.
    await asyncio.gather(*(execute_workers(f"task {task_num}")
                           for task_num in range(4)))

    # Test error case
    with pytest.raises(ValueError, match="^fake error$"):
        await workers[0].execute_method_async("worker_method",
                                              ValueError("fake error"))

    # Test cleanup when a worker fails
    assert worker_monitor.is_alive()
    workers[3].process.kill()

    # Other workers should get shut down here
    worker_monitor.join(20)

    # Ensure everything is stopped
    _assert_workers_stopped(workers, worker_monitor)

    # Further attempts to submit tasks should fail
    with pytest.raises(ChildProcessError):
        await workers[0].execute_method_async("worker_method", "test")
|