[Bugfix] AsyncLLMEngine hangs with asyncio.run (#5654)
This commit is contained in:
parent
d571ca0108
commit
78687504f7
@ -2,8 +2,12 @@ import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
|
||||
|
||||
from ..utils import wait_for_gpu_memory_to_clear
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -94,3 +98,35 @@ async def test_new_requests_event():
|
||||
assert engine.get_model_config() is not None
|
||||
assert engine.get_tokenizer() is not None
|
||||
assert engine.get_decoding_config() is not None
|
||||
|
||||
|
||||
def test_asyncio_run():
|
||||
wait_for_gpu_memory_to_clear(
|
||||
devices=list(range(torch.cuda.device_count())),
|
||||
threshold_bytes=2 * 2**30,
|
||||
timeout_s=60,
|
||||
)
|
||||
|
||||
engine = AsyncLLMEngine.from_engine_args(
|
||||
AsyncEngineArgs(model="facebook/opt-125m"))
|
||||
|
||||
async def run(prompt: str):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
)
|
||||
|
||||
async for output in engine.generate(prompt,
|
||||
sampling_params,
|
||||
request_id=prompt):
|
||||
final_output = output
|
||||
return final_output
|
||||
|
||||
async def generate():
|
||||
return await asyncio.gather(
|
||||
run("test0"),
|
||||
run("test1"),
|
||||
)
|
||||
|
||||
results = asyncio.run(generate())
|
||||
assert len(results) == 2
|
||||
|
@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import time
|
||||
from itertools import cycle
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
@ -7,12 +6,6 @@ import pytest
|
||||
import ray
|
||||
import torch
|
||||
|
||||
from vllm.utils import is_hip
|
||||
|
||||
if (not is_hip()):
|
||||
from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
|
||||
nvmlInit)
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
@ -26,6 +19,7 @@ from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Counter, random_uuid
|
||||
|
||||
from ...conftest import cleanup
|
||||
from ...utils import wait_for_gpu_memory_to_clear
|
||||
|
||||
|
||||
class AsyncLLM:
|
||||
@ -291,38 +285,3 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
print(f'{i=} {baseline_token_ids=}')
|
||||
print(f'{i=} {spec_token_ids=}')
|
||||
assert baseline_token_ids == spec_token_ids
|
||||
|
||||
|
||||
def wait_for_gpu_memory_to_clear(devices: List[int],
|
||||
threshold_bytes: int,
|
||||
timeout_s: float = 120) -> None:
|
||||
# Use nvml instead of pytorch to reduce measurement error from torch cuda
|
||||
# context.
|
||||
nvmlInit()
|
||||
start_time = time.time()
|
||||
while True:
|
||||
output: Dict[int, str] = {}
|
||||
output_raw: Dict[int, float] = {}
|
||||
for device in devices:
|
||||
dev_handle = nvmlDeviceGetHandleByIndex(device)
|
||||
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
|
||||
gb_used = mem_info.used / 2**30
|
||||
output_raw[device] = gb_used
|
||||
output[device] = f'{gb_used:.02f}'
|
||||
|
||||
print('gpu memory used (GB): ', end='')
|
||||
for k, v in output.items():
|
||||
print(f'{k}={v}; ', end='')
|
||||
print('')
|
||||
|
||||
dur_s = time.time() - start_time
|
||||
if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
|
||||
print(f'Done waiting for free GPU memory on devices {devices=} '
|
||||
f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
|
||||
break
|
||||
|
||||
if dur_s >= timeout_s:
|
||||
raise ValueError(f'Memory of devices {devices=} not free after '
|
||||
f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
|
||||
|
||||
time.sleep(5)
|
||||
|
@ -4,7 +4,7 @@ import sys
|
||||
import time
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
|
||||
import openai
|
||||
import ray
|
||||
@ -13,7 +13,11 @@ import requests
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||
from vllm.utils import get_open_port
|
||||
from vllm.utils import get_open_port, is_hip
|
||||
|
||||
if (not is_hip()):
|
||||
from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
|
||||
nvmlInit)
|
||||
|
||||
# Path to root of repository so that utilities can be imported by ray workers
|
||||
VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
|
||||
@ -154,3 +158,38 @@ def error_on_warning():
|
||||
warnings.simplefilter("error")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def wait_for_gpu_memory_to_clear(devices: List[int],
|
||||
threshold_bytes: int,
|
||||
timeout_s: float = 120) -> None:
|
||||
# Use nvml instead of pytorch to reduce measurement error from torch cuda
|
||||
# context.
|
||||
nvmlInit()
|
||||
start_time = time.time()
|
||||
while True:
|
||||
output: Dict[int, str] = {}
|
||||
output_raw: Dict[int, float] = {}
|
||||
for device in devices:
|
||||
dev_handle = nvmlDeviceGetHandleByIndex(device)
|
||||
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
|
||||
gb_used = mem_info.used / 2**30
|
||||
output_raw[device] = gb_used
|
||||
output[device] = f'{gb_used:.02f}'
|
||||
|
||||
print('gpu memory used (GB): ', end='')
|
||||
for k, v in output.items():
|
||||
print(f'{k}={v}; ', end='')
|
||||
print('')
|
||||
|
||||
dur_s = time.time() - start_time
|
||||
if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
|
||||
print(f'Done waiting for free GPU memory on devices {devices=} '
|
||||
f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
|
||||
break
|
||||
|
||||
if dur_s >= timeout_s:
|
||||
raise ValueError(f'Memory of devices {devices=} not free after '
|
||||
f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
|
||||
|
||||
time.sleep(5)
|
||||
|
@ -10,6 +10,7 @@ import vllm.envs as envs
|
||||
from vllm.config import DecodingConfig, ModelConfig
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_timeout import asyncio_timeout
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster, ray
|
||||
from vllm.inputs import LLMInputs, PromptInputs
|
||||
@ -545,8 +546,8 @@ class AsyncLLMEngine:
|
||||
# Abort if iteration takes too long due to unrecoverable errors
|
||||
# (eg. NCCL timeouts).
|
||||
try:
|
||||
has_requests_in_progress = await asyncio.wait_for(
|
||||
self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
|
||||
async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
|
||||
has_requests_in_progress = await self.engine_step()
|
||||
except asyncio.TimeoutError as exc:
|
||||
logger.error(
|
||||
"Engine iteration timed out. This should never happen!")
|
||||
|
189
vllm/engine/async_timeout.py
Normal file
189
vllm/engine/async_timeout.py
Normal file
@ -0,0 +1,189 @@
|
||||
# Workaround for https://github.com/python/cpython/issues/86296
|
||||
#
|
||||
# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
|
||||
# Licensed under the Apache License (Apache-2.0)
|
||||
|
||||
import asyncio
|
||||
import enum
|
||||
import sys
|
||||
import warnings
|
||||
from types import TracebackType
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from asyncio import timeout as asyncio_timeout
|
||||
else:
|
||||
|
||||
def asyncio_timeout(delay: Optional[float]) -> "Timeout":
|
||||
"""timeout context manager.
|
||||
Useful in cases when you want to apply timeout logic around block
|
||||
of code or in cases when asyncio.wait_for is not suitable. For example:
|
||||
>>> async with timeout(0.001):
|
||||
... async with aiohttp.get('https://github.com') as r:
|
||||
... await r.text()
|
||||
delay - value in seconds or None to disable timeout logic
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
deadline = loop.time() + delay if delay is not None else None
|
||||
return Timeout(deadline, loop)
|
||||
|
||||
class _State(enum.Enum):
|
||||
INIT = "INIT"
|
||||
ENTER = "ENTER"
|
||||
TIMEOUT = "TIMEOUT"
|
||||
EXIT = "EXIT"
|
||||
|
||||
class Timeout:
|
||||
# Internal class, please don't instantiate it directly
|
||||
# Use timeout() and timeout_at() public factories instead.
|
||||
#
|
||||
# Implementation note: `async with timeout()` is preferred
|
||||
# over `with timeout()`.
|
||||
# While technically the Timeout class implementation
|
||||
# doesn't need to be async at all,
|
||||
# the `async with` statement explicitly points that
|
||||
# the context manager should be used from async function context.
|
||||
#
|
||||
# This design allows to avoid many silly misusages.
|
||||
#
|
||||
# TimeoutError is raised immediately when scheduled
|
||||
# if the deadline is passed.
|
||||
# The purpose is to time out as soon as possible
|
||||
# without waiting for the next await expression.
|
||||
|
||||
__slots__ = ("_deadline", "_loop", "_state", "_timeout_handler")
|
||||
|
||||
def __init__(self, deadline: Optional[float],
|
||||
loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._loop = loop
|
||||
self._state = _State.INIT
|
||||
|
||||
self._timeout_handler = None # type: Optional[asyncio.Handle]
|
||||
if deadline is None:
|
||||
self._deadline = None # type: Optional[float]
|
||||
else:
|
||||
self.update(deadline)
|
||||
|
||||
def __enter__(self) -> "Timeout":
|
||||
warnings.warn(
|
||||
"with timeout() is deprecated, use async with timeout()",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._do_enter()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
self._do_exit(exc_type)
|
||||
return None
|
||||
|
||||
async def __aenter__(self) -> "Timeout":
|
||||
self._do_enter()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
self._do_exit(exc_type)
|
||||
return None
|
||||
|
||||
@property
|
||||
def expired(self) -> bool:
|
||||
"""Is timeout expired during execution?"""
|
||||
return self._state == _State.TIMEOUT
|
||||
|
||||
@property
|
||||
def deadline(self) -> Optional[float]:
|
||||
return self._deadline
|
||||
|
||||
def reject(self) -> None:
|
||||
"""Reject scheduled timeout if any."""
|
||||
# cancel is maybe better name but
|
||||
# task.cancel() raises CancelledError in asyncio world.
|
||||
if self._state not in (_State.INIT, _State.ENTER):
|
||||
raise RuntimeError(f"invalid state {self._state.value}")
|
||||
self._reject()
|
||||
|
||||
def _reject(self) -> None:
|
||||
if self._timeout_handler is not None:
|
||||
self._timeout_handler.cancel()
|
||||
self._timeout_handler = None
|
||||
|
||||
def shift(self, delay: float) -> None:
|
||||
"""Advance timeout on delay seconds.
|
||||
The delay can be negative.
|
||||
Raise RuntimeError if shift is called when deadline is not scheduled
|
||||
"""
|
||||
deadline = self._deadline
|
||||
if deadline is None:
|
||||
raise RuntimeError(
|
||||
"cannot shift timeout if deadline is not scheduled")
|
||||
self.update(deadline + delay)
|
||||
|
||||
def update(self, deadline: float) -> None:
|
||||
"""Set deadline to absolute value.
|
||||
deadline argument points on the time in the same clock system
|
||||
as loop.time().
|
||||
If new deadline is in the past the timeout is raised immediately.
|
||||
Please note: it is not POSIX time but a time with
|
||||
undefined starting base, e.g. the time of the system power on.
|
||||
"""
|
||||
if self._state == _State.EXIT:
|
||||
raise RuntimeError(
|
||||
"cannot reschedule after exit from context manager")
|
||||
if self._state == _State.TIMEOUT:
|
||||
raise RuntimeError("cannot reschedule expired timeout")
|
||||
if self._timeout_handler is not None:
|
||||
self._timeout_handler.cancel()
|
||||
self._deadline = deadline
|
||||
if self._state != _State.INIT:
|
||||
self._reschedule()
|
||||
|
||||
def _reschedule(self) -> None:
|
||||
assert self._state == _State.ENTER
|
||||
deadline = self._deadline
|
||||
if deadline is None:
|
||||
return
|
||||
|
||||
now = self._loop.time()
|
||||
if self._timeout_handler is not None:
|
||||
self._timeout_handler.cancel()
|
||||
|
||||
task = asyncio.current_task()
|
||||
if deadline <= now:
|
||||
self._timeout_handler = self._loop.call_soon(
|
||||
self._on_timeout, task)
|
||||
else:
|
||||
self._timeout_handler = self._loop.call_at(
|
||||
deadline, self._on_timeout, task)
|
||||
|
||||
def _do_enter(self) -> None:
|
||||
if self._state != _State.INIT:
|
||||
raise RuntimeError(f"invalid state {self._state.value}")
|
||||
self._state = _State.ENTER
|
||||
self._reschedule()
|
||||
|
||||
def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
|
||||
if exc_type is asyncio.CancelledError and \
|
||||
self._state == _State.TIMEOUT:
|
||||
self._timeout_handler = None
|
||||
raise asyncio.TimeoutError
|
||||
# timeout has not expired
|
||||
self._state = _State.EXIT
|
||||
self._reject()
|
||||
return None
|
||||
|
||||
def _on_timeout(self, task: "Optional[asyncio.Task[Any]]") -> None:
|
||||
if task:
|
||||
task.cancel()
|
||||
self._state = _State.TIMEOUT
|
||||
# drop the reference early
|
||||
self._timeout_handler = None
|
Loading…
x
Reference in New Issue
Block a user