vllm/tests/async_engine/test_async_llm_engine.py

import asyncio
import os
import uuid
from asyncio import CancelledError
from copy import copy
from dataclasses import dataclass
from typing import List, Optional
import pytest
import pytest_asyncio
import torch
from vllm import SamplingParams
from vllm.config import ParallelConfig
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from vllm.outputs import RequestOutput as RealRequestOutput
from vllm.sampling_params import RequestOutputKind
from ..conftest import cleanup
from ..utils import wait_for_gpu_memory_to_clear
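

# Lightweight mocks used by test_new_requests_event below; they stand in for
# the real engine so the background-loop logic can be exercised without
# loading a model or touching a GPU.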
@dataclass
class RequestOutput:
request_id: int
finished: bool = False


@dataclass
class MockModelConfig:
use_async_output_proc = True


class MockEngine:
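    """Minimal engine stand-in that counts calls instead of running a model."""
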
def __init__(self):
self.step_calls = 0
self.add_request_calls = 0
self.abort_request_calls = 0
self.request_id = None
# Ugly, remove dependency when possible
self.parallel_config = ParallelConfig(1, 1, False)
self.model_config = MockModelConfig()
async def step_async(self, virtual_engine):
        # Pipeline-parallel size is 1, so the virtual engine index is ignored.
self.step_calls += 1
return [RequestOutput(
request_id=self.request_id)] if self.request_id else []
async def process_model_inputs_async(self, *args, **kwargs):
pass
async def stop_remote_worker_execution_loop_async(self):
pass
def generate(self, request_id):
self.request_id = request_id
def stop_generating(self):
self.request_id = None
def add_request(self, **kwargs):
del kwargs # Unused
self.add_request_calls += 1
print(f'Request calls: {self.add_request_calls}')
async def add_request_async(self, **kwargs):
self.add_request_calls += 1
return
def abort_request(self, request_id):
del request_id # Unused
self.abort_request_calls += 1
def has_unfinished_requests(self):
return self.request_id is not None
def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
return self.request_id is not None


class MockAsyncLLMEngine(AsyncLLMEngine):
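    """AsyncLLMEngine wired to MockEngine so no model is loaded."""
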
_engine_class = MockEngine


@pytest.mark.asyncio
async def test_new_requests_event():
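    """The background loop should step only while there are unfinished
    requests and should wake up as soon as a new request is added."""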
params = SamplingParams()
engine = MockAsyncLLMEngine()
engine.start_background_loop()
await asyncio.sleep(0.01)
assert engine.engine.step_calls == 0
await engine.add_request("1", "", params)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 1
assert engine.engine.step_calls == 1
await engine.add_request("2", "", params)
engine.engine.generate("2")
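    # Yield to the event loop a few times so the background loop can run.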
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
assert engine.engine.add_request_calls == 2
assert engine.engine.step_calls >= 2
await asyncio.sleep(0.001)
assert engine.engine.step_calls >= 3
engine.engine.stop_generating()
await asyncio.sleep(0.001)
old_step_calls = engine.engine.step_calls
await asyncio.sleep(0.001)
assert engine.engine.step_calls == old_step_calls
await engine.add_request("3", "", params)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == old_step_calls + 1
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == old_step_calls + 1
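
    # A freshly constructed mock engine should still expose its configs.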
engine = MockAsyncLLMEngine()
assert engine.get_model_config() is not None
assert engine.get_tokenizer() is not None
assert engine.get_decoding_config() is not None
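

# The tests below run a real AsyncLLMEngine with facebook/opt-125m;
# NUM_SCHEDULER_STEPS selects single-step vs. multi-step scheduling.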
def start_engine():
wait_for_gpu_memory_to_clear(
devices=list(range(torch.cuda.device_count())),
threshold_bytes=2 * 2**30,
timeout_s=60,
)
num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1"))
print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")
return AsyncLLMEngine.from_engine_args(
AsyncEngineArgs(model="facebook/opt-125m",
enforce_eager=True,
num_scheduler_steps=num_scheduler_steps))


def uid() -> str:
return str(uuid.uuid4())


@pytest_asyncio.fixture(scope="module")
async def async_engine():
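    """Module-scoped engine shared by the tests below; built off the event
    loop in an executor and shut down once the module's tests are done."""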
engine = await asyncio.get_event_loop().run_in_executor(executor=None,
func=start_engine)
try:
yield engine
finally:
engine.shutdown_background_loop()
del engine
await asyncio.sleep(0.1)
cleanup()


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
# So we can share the async engine fixture between these tests
return False


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_asyncio_run(async_engine, stop):
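    """Two identical greedy requests run concurrently should produce
    identical outputs."""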
scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps
async def run(prompt: str):
sampling_params = SamplingParams(
temperature=0,
max_tokens=32,
min_tokens=32,
stop=stop,
)
output_count = 0
final_output = None
async for output in async_engine.generate(prompt,
sampling_params,
request_id=uid()):
output_count += 1
final_output = output
return final_output, output_count
results = await asyncio.gather(
run("test0"),
run("test0"),
)
assert len(results) == 2
first, second = results
# remove nondeterministic fields for comparison
first[0].metrics = None
second[0].metrics = None
first[0].request_id = None
second[0].request_id = None
assert str(first) == str(second)
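
    # With multi-step scheduling the engine emits several tokens per update,
    # so fewer streamed outputs than generated tokens are expected.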
output_count = results[0][1]
if num_scheduler_steps == 1:
assert output_count == 32
else:
assert 1 < output_count < 32


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_output_kinds(async_engine, stop):
"""Test that output_kind works as expected and that
results are equivalent across different kinds."""
scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps
sampling_params = SamplingParams(
temperature=0,
max_tokens=32,
min_tokens=32,
stop=stop,
)
async def run(prompt: str, kind: RequestOutputKind):
params = copy(sampling_params)
params.output_kind = kind
output_count = 0
final_output = None
async for output in async_engine.generate(prompt,
params,
request_id=uid()):
output_count += 1
final_output = output
assert final_output is not None
assert final_output.finished
return (final_output.prompt_token_ids,
final_output.outputs[0].token_ids,
final_output.outputs[0].text, output_count)
async def run_deltas(prompt: str):
params = copy(sampling_params)
params.output_kind = RequestOutputKind.DELTA
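        # DELTA outputs carry only the tokens generated since the previous
        # update, so accumulate them to rebuild the full completion.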
prompt_tokens = None
output_tokens: List[int] = []
output_text = ""
output_count = 0
final_output = None
async for output in async_engine.generate(prompt,
params,
request_id=uid()):
token_ids = output.outputs[0].token_ids
text = output.outputs[0].text
final_output = output
# Ensure we get prompt ids iff we haven't yet received output tokens
if output_tokens:
assert 1 <= len(token_ids) <= num_scheduler_steps
assert stop or text
assert not output.prompt_token_ids
else:
assert output.prompt_token_ids
prompt_tokens = output.prompt_token_ids
output_tokens.extend(token_ids)
output_text += text
output_count += 1
assert final_output is not None
assert final_output.finished
return prompt_tokens, output_tokens, output_text, output_count
results = await asyncio.gather(
run("common input prompt", RequestOutputKind.CUMULATIVE),
run("common input prompt", RequestOutputKind.FINAL_ONLY),
run_deltas("common input prompt"))
# Make sure outputs are the same
prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results)
assert len(prompt_set) == 1
text_set = set(text for _, _, text, _ in results)
assert len(text_set) == 1
tokens_set = set(tuple(ids) for _, ids, _, _ in results)
assert len(tokens_set) == 1
cumulative, final, deltas = results
# output message counts
assert cumulative[3] == deltas[3]
if num_scheduler_steps == 1:
assert cumulative[3] == 32
else:
assert 1 < cumulative[3] < 32
assert final[3] == 1


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_cancellation(async_engine, stop):
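    """Aborting an in-flight request should cancel its output stream."""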
scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps
sampling_params = SamplingParams(
temperature=0,
min_tokens=13,
max_tokens=13,
stop=stop,
)
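
    # With multi-step scheduling fewer updates are streamed before the 13
    # tokens finish, so abort after the first one instead of the fifth.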
stop_at = 5 if num_scheduler_steps == 1 else 1
request_id = uid()
i = 0
with pytest.raises(CancelledError):
async for output in async_engine.generate("test2",
sampling_params,
request_id=request_id):
assert not output.finished
i += 1
if i == stop_at:
await async_engine.abort(request_id)
assert i == stop_at


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_delayed_generator(async_engine, stop):
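    """A consumer that stalls after the first output should still receive
    every streamed message once it resumes."""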
scheduler_config = await async_engine.get_scheduler_config()
if scheduler_config.num_scheduler_steps != 1:
pytest.skip("no need to test this one with multistep")
sampling_params = SamplingParams(
temperature=0,
min_tokens=10,
max_tokens=10,
stop=stop,
)
stream = async_engine.generate("test3", sampling_params, request_id=uid())
i = 0
final_output: Optional[RealRequestOutput] = None
async for output in stream:
final_output = output
if i == 0:
# wait for generation to complete before consuming
# the remaining messages
await asyncio.sleep(1)
if i < 9:
assert not output.finished
i += 1
assert i == 10
assert final_output is not None
assert len(final_output.outputs[0].token_ids) == 10
assert final_output.finished