vllm/tests/async_engine/test_async_llm_engine.py
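
"""Tests for AsyncLLMEngine: background-loop wakeup behavior against a mock
engine, plus end-to-end generation, cancellation, and slow-consumer
streaming against a real engine."""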


import asyncio
from asyncio import CancelledError
from dataclasses import dataclass
from typing import Optional

import pytest
import pytest_asyncio
import torch

from vllm import SamplingParams
from vllm.config import ParallelConfig
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from vllm.outputs import RequestOutput as RealRequestOutput

from ..conftest import cleanup
from ..utils import wait_for_gpu_memory_to_clear


@dataclass
class RequestOutput:
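    """Minimal stand-in for vllm.outputs.RequestOutput, carrying only the
    fields the mock engine needs."""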
    request_id: int
    finished: bool = False


class MockEngine:
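    """Synchronous-engine stand-in that counts calls and reports whether a
    request is in flight, so the background loop can be tested in isolation."""
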
    def __init__(self):
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        self.request_id = None
        # Ugly, remove dependency when possible
        self.parallel_config = ParallelConfig(1, 1, False)

    async def step_async(self, virtual_engine):
        # PP size is 1, ignore virtual engine
        self.step_calls += 1
        return [RequestOutput(
            request_id=self.request_id)] if self.request_id else []

    async def process_model_inputs_async(self, *args, **kwargs):
        pass

    async def stop_remote_worker_execution_loop_async(self):
        pass

    def generate(self, request_id):
        self.request_id = request_id

    def stop_generating(self):
        self.request_id = None

    def add_request(self, **kwargs):
        del kwargs  # Unused
        self.add_request_calls += 1
        print(f'Request calls: {self.add_request_calls}')

    async def add_request_async(self, **kwargs):
        self.add_request_calls += 1
        return

    def abort_request(self, request_id):
        del request_id  # Unused
        self.abort_request_calls += 1

    def has_unfinished_requests(self):
        return self.request_id is not None

    def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
        return self.request_id is not None


class MockAsyncLLMEngine(AsyncLLMEngine):
    _engine_class = MockEngine


@pytest.mark.asyncio
async def test_new_requests_event():
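    """The background loop should step only while requests are unfinished,
    and wake promptly when a new request arrives."""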
    engine = MockAsyncLLMEngine(worker_use_ray=False)
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    assert engine.engine.step_calls == 0

    # The loop should wake up as soon as the first request arrives.
    await engine.add_request("1", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    # While a request is in flight, the loop keeps stepping on its own.
    await engine.add_request("2", "", None)
    engine.engine.generate("2")
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls >= 2
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls >= 3

    # Once generation stops, stepping should stop too.
    engine.engine.stop_generating()
    await asyncio.sleep(0.001)
    old_step_calls = engine.engine.step_calls
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls == old_step_calls

    # A new request triggers exactly one more step before the loop idles.
    await engine.add_request("3", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1

    # With worker_use_ray=True, the config accessors should still work
    # without the background loop running.
    engine = MockAsyncLLMEngine(worker_use_ray=True)
    assert engine.get_model_config() is not None
    assert engine.get_tokenizer() is not None
    assert engine.get_decoding_config() is not None


def start_engine():
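    """Wait for GPU memory to clear, then build a real AsyncLLMEngine."""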
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )

    return AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))


@pytest_asyncio.fixture(scope="module")
async def async_engine():
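    # Engine startup blocks on model loading, so run it in the default
    # executor to keep the event loop responsive.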
    engine = await asyncio.get_event_loop().run_in_executor(executor=None,
                                                            func=start_engine)
    try:
        yield engine
    finally:
        engine.shutdown_background_loop()
        del engine
        await asyncio.sleep(0.1)
        cleanup()


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    # So we can share the async engine fixture between these tests
    return False


@pytest.mark.asyncio(scope="module")
async def test_asyncio_run(async_engine):
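    """Concurrent generate() calls should each run to completion."""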
    async def run(prompt: str):
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=32,
        )

        async for output in async_engine.generate(prompt,
                                                  sampling_params,
                                                  request_id=prompt):
            final_output = output
        return final_output

    results = await asyncio.gather(
        run("test0"),
        run("test1"),
    )
    assert len(results) == 2


@pytest.mark.asyncio(scope="module")
async def test_cancellation(async_engine):
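    """Aborting a request mid-stream should raise CancelledError in the
    consuming generator."""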
    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=10,
        max_tokens=10,
    )

    i = 0
    with pytest.raises(CancelledError):
        async for output in async_engine.generate("test2",
                                                  sampling_params,
                                                  request_id="test2"):
            assert not output.finished
            i += 1
            if i == 5:
                await async_engine.abort("test2")

    assert i == 5


@pytest.mark.asyncio(scope="module")
async def test_delayed_generator(async_engine):
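    """A consumer that stalls mid-stream should still receive every output,
    with only the last one marked finished."""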
    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=10,
        max_tokens=10,
    )

    stream = async_engine.generate("test3",
                                   sampling_params,
                                   request_id="test3")

    i = 0
    final_output: Optional[RealRequestOutput] = None
    async for output in stream:
        final_output = output
        if i == 0:
            # wait for generation to complete before consuming
            # the remaining messages
            await asyncio.sleep(1)
        if i < 9:
            assert not output.finished
        i += 1

    assert i == 10
    assert final_output is not None
    assert len(final_output.outputs[0].token_ids) == 10
    assert final_output.finished