# vllm/cacheflow/server/async_llm_server.py

import asyncio
import time
from typing import AsyncIterator, Dict, Optional

from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.llm_server import LLMServer
from cacheflow.server.ray_utils import ray, initialize_cluster

logger = init_logger(__name__)

TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds


class AsyncLLMServer:
    """Asynchronous wrapper around LLMServer.

    The server has no background loop; instead, the generate() coroutines
    themselves repeatedly "kick" server_step() until their outputs are ready,
    so multiple concurrent requests can share a single server instance.
    """

    def __init__(self, worker_use_ray: bool, server_use_ray: bool,
                 *args, **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.server_use_ray = server_use_ray
        if not self.server_use_ray:
            server_class = LLMServer
        elif self.worker_use_ray:
            server_class = ray.remote(num_cpus=0)(LLMServer).remote
        else:
            server_class = ray.remote(num_gpus=1)(LLMServer).remote
        self.server = server_class(*args, **kwargs)

        # Request id -> request output.
        self.request_outputs: Dict[str, RequestOutput] = {}
        # Request id -> event to notify that there is new output.
        self.request_events: Dict[str, asyncio.Event] = {}
        self.is_server_running = False
        self.kicking_request_id: Optional[str] = None

    async def server_step(self, kicking_request_id: Optional[str] = None):
        """Run one step of the underlying server and publish its outputs."""
        self.is_server_running = True
        self.kicking_request_id = kicking_request_id
        if self.server_use_ray:
            request_outputs = await self.server.step.remote()
        else:
            # Yield to the event loop to allow other coroutines to run
            # while is_server_running is True. This lets the server add new
            # requests into the queue.
            await asyncio.sleep(0)
            request_outputs = self.server.step()
        self.is_server_running = False
        self.kicking_request_id = None

        # Notify the waiting coroutines that there are new outputs ready.
        for request_output in request_outputs:
            request_id = request_output.request_id
            self.request_outputs[request_id] = request_output
            self.request_events[request_id].set()

    async def generate(self, prompt: str, sampling_params: SamplingParams,
                       request_id: str) -> AsyncIterator[RequestOutput]:
        """Add a request to the server and yield its outputs as they arrive."""
        # Preprocess the request.
        arrival_time = time.time()
        # Create an event to notify us that there is new output from the
        # cacheflow server.
        request_event = asyncio.Event()
        self.request_events[request_id] = request_event

        logger.info(f"Received request {request_id}: "
                    f"prompt: {prompt!r}, "
                    f"sampling params: {sampling_params}.")

        # Add the request into the cacheflow server's waiting queue.
        if self.server_use_ray:
            await self.server.add_request.remote(
                request_id, prompt, sampling_params, arrival_time=arrival_time)
        else:
            self.server.add_request(
                request_id, prompt, sampling_params, arrival_time=arrival_time)

        # The cacheflow server does not have a background loop that keeps
        # processing incoming requests. Therefore, we need to keep kicking
        # the server to process the requests.
        while True:
            # Kick the server if the server is not running.
            if not self.is_server_running:
                await self.server_step(request_id)

            # Wait for new output. The request_event will be set in
            # server_step when there is new output available for this
            # request. Use a timeout to prevent deadlock.
            try:
                await asyncio.wait_for(request_event.wait(),
                                       timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
            except asyncio.TimeoutError:
                continue
            # Reset the event to wait for the next output.
            request_event.clear()

            # Yield the new output to the caller.
            request_output = self.request_outputs[request_id]
            yield request_output

            # Once the request is finished, release its resources.
            if request_output.finished():
                logger.info(f"Finished request {request_id}.")
                del self.request_outputs[request_id]
                del self.request_events[request_id]
                # Kick the server once more if it is not running, so that
                # any requests still in the server's waiting queue get
                # executed.
                if not self.is_server_running:
                    await self.server_step()
                break

    async def abort(self, request_id: str) -> None:
        """Abort a request and clean up its bookkeeping state."""
        if request_id not in self.request_events:
            # The request has already finished or been aborted.
            return

        logger.info(f"Aborted request {request_id}.")
        if self.server_use_ray:
            await self.server.abort_request.remote(request_id)
        else:
            self.server.abort_request(request_id)

        if request_id in self.request_events:
            del self.request_events[request_id]
        if request_id in self.request_outputs:
            del self.request_outputs[request_id]

        # To prevent deadlock when a request is aborted while the server is
        # running on its behalf, mark the server as not running so the next
        # generate() call can kick it again.
        if self.kicking_request_id == request_id:
            self.is_server_running = False
            self.kicking_request_id = None

    @classmethod
    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
        """Create an AsyncLLMServer from AsyncServerArgs."""
        # Create the server configs.
        server_configs = server_args.create_server_configs()
        parallel_config = server_configs[2]
        # Initialize the cluster.
        distributed_init_method, devices = initialize_cluster(
            parallel_config, server_args.server_use_ray)

        # Create the LLM server.
        server = cls(server_args.worker_use_ray,
                     server_args.server_use_ray,
                     *server_configs,
                     distributed_init_method, devices,
                     log_stats=not server_args.disable_log_stats)
        return server
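

# Illustrative usage sketch (not part of the original module): it shows how a
# caller could drive AsyncLLMServer.generate() from a plain asyncio program.
# The AsyncServerArgs constructor field (`model`) and the SamplingParams
# arguments used below are assumptions and may not match this codebase.
if __name__ == "__main__":

    async def _demo() -> None:
        # Assumed: AsyncServerArgs takes the model name as `model`.
        server_args = AsyncServerArgs(model="facebook/opt-125m")
        server = AsyncLLMServer.from_server_args(server_args)
        # Assumed: SamplingParams accepts `temperature`.
        sampling_params = SamplingParams(temperature=0.8)
        # generate() is an async generator: it yields a RequestOutput each
        # time the server produces new output for this request.
        async for request_output in server.generate(
                "Hello, my name is", sampling_params, request_id="demo-0"):
            print(request_output)

    asyncio.run(_demo())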