Use queue for finished requests (#957)

Antoni Baum 2023-09-05 19:27:23 -07:00 committed by GitHub
parent fbd80ad409
commit c9927c1a6a
2 changed files with 13 additions and 9 deletions
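
For context: the old code awaited `_engine_abort` on `self.finished_requests` (a set) and then called `clear()`, so any id added while the abort was being awaited could be wiped without its stream ever being removed. The new code drains an `asyncio.Queue` into a local set first, so ids enqueued mid-await simply wait for the next engine iteration. A minimal runnable sketch of that drain pattern, detached from the real engine (`engine_step`, the `request_streams` dict, and the request ids below are illustrative stand-ins, not the vLLM API):

import asyncio
from typing import Dict, Set


async def engine_step(finished_requests: "asyncio.Queue[str]",
                      request_streams: Dict[str, object]) -> Set[str]:
    # Drain the queue into a local set in one pass; anything enqueued
    # while a later await is pending is picked up on the next iteration
    # instead of being lost to a blanket clear().
    finished: Set[str] = set()
    while not finished_requests.empty():
        finished.add(finished_requests.get_nowait())
    for request_id in finished:
        del request_streams[request_id]
    return finished


async def main() -> None:
    finished_requests: "asyncio.Queue[str]" = asyncio.Queue()
    request_streams: Dict[str, object] = {"a": object(), "b": object()}
    # Producers use put_nowait: the queue is unbounded, so it cannot raise
    # QueueFull, and no await is needed on the hot path.
    finished_requests.put_nowait("a")
    finished_requests.put_nowait("b")
    done = await engine_step(finished_requests, request_streams)
    print(done, request_streams)  # {'a', 'b'} {} (set order may vary)


asyncio.run(main())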

tests/conftest.py

@@ -156,8 +156,8 @@ class VllmRunner:
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
         outputs = self.generate(prompts, greedy_params)
-        return [(output_ids[0], output_str[0]) for output_ids, output_str in
-                outputs]
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]
 
     def generate_beam_search(
         self,

vllm/engine/async_llm_engine.py

@@ -1,7 +1,7 @@
 import asyncio
 import time
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Type, Union
 
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -152,7 +152,7 @@ class AsyncLLMEngine:
         # Request id -> stream.
         self.request_streams: Dict[str, AsyncStream] = {}
-        self.finished_requests: Set[str] = set()
+        self.finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self.background_loop = None
         if start_engine_loop:
             self.start_background_loop()
 
@@ -194,12 +194,14 @@ class AsyncLLMEngine:
                 if self.log_requests:
                     logger.info(f"Finished request {request_id}.")
                 self.request_streams[request_id].finish()
-                self.finished_requests.add(request_id)
+                self.finished_requests.put_nowait(request_id)
 
-        await self._engine_abort(self.finished_requests)
-        for request_id in self.finished_requests:
+        finished_request = set()
+        while not self.finished_requests.empty():
+            finished_request.add(self.finished_requests.get_nowait())
+        await self._engine_abort(finished_request)
+        for request_id in finished_request:
             del self.request_streams[request_id]
-        self.finished_requests.clear()
 
     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
@@ -226,6 +228,8 @@ class AsyncLLMEngine:
                         f"sampling params: {sampling_params}, "
                         f"prompt token ids: {prompt_token_ids}.")
 
+        if request_id in self.request_streams:
+            raise KeyError(f"Request {request_id} already exists.")
 
         stream = AsyncStream(request_id)
         self.request_streams[request_id] = stream
@@ -316,7 +320,7 @@ class AsyncLLMEngine:
             logger.info(f"Aborted request {request_id}.")
 
         self.request_streams[request_id].finish()
-        self.finished_requests.add(request_id)
+        self.finished_requests.put_nowait(request_id)
 
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""