Use queue for finished requests (#957)

Antoni Baum 2023-09-05 19:27:23 -07:00 committed by GitHub
parent fbd80ad409
commit c9927c1a6a
2 changed files with 13 additions and 9 deletions


@@ -156,8 +156,8 @@ class VllmRunner:
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
         outputs = self.generate(prompts, greedy_params)
-        return [(output_ids[0], output_str[0]) for output_ids, output_str in
-                outputs]
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]
 
     def generate_beam_search(
         self,


@@ -1,7 +1,7 @@
 import asyncio
 import time
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Type, Union
 
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -152,7 +152,7 @@ class AsyncLLMEngine:
 
         # Request id -> stream.
         self.request_streams: Dict[str, AsyncStream] = {}
-        self.finished_requests: Set[str] = set()
+        self.finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self.background_loop = None
         if start_engine_loop:
             self.start_background_loop()
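
The switch from a plain set to an asyncio.Queue means finished request ids are handed off through the queue's non-blocking put_nowait / get_nowait / empty calls rather than mutated in place by several coroutines. A minimal standalone sketch of those semantics (the request ids below are made up for illustration):

import asyncio


async def main() -> None:
    finished: asyncio.Queue = asyncio.Queue()

    # Producers record finished ids without awaiting or blocking.
    finished.put_nowait("req-0")
    finished.put_nowait("req-1")

    # A consumer later drains whatever is queued at that moment.
    drained = set()
    while not finished.empty():
        drained.add(finished.get_nowait())
    print(drained)  # {'req-0', 'req-1'} (set order may vary)


asyncio.run(main())
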
@@ -194,12 +194,14 @@
                 if self.log_requests:
                     logger.info(f"Finished request {request_id}.")
                 self.request_streams[request_id].finish()
-                self.finished_requests.add(request_id)
+                self.finished_requests.put_nowait(request_id)
 
-        await self._engine_abort(self.finished_requests)
-        for request_id in self.finished_requests:
+        finished_request = set()
+        while not self.finished_requests.empty():
+            finished_request.add(self.finished_requests.get_nowait())
+        await self._engine_abort(finished_request)
+        for request_id in finished_request:
             del self.request_streams[request_id]
-        self.finished_requests.clear()
 
     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
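
Taken together, the background step now drains the queue into a throwaway local set before aborting those requests and dropping their streams, instead of iterating over and then clearing a shared set. A rough sketch of that pattern with hypothetical names (FinishedTracker and engine_abort are illustrative stand-ins, not vLLM's API):

import asyncio
from typing import Iterable


class FinishedTracker:
    def __init__(self) -> None:
        # Stand-ins for AsyncLLMEngine's request_streams / finished_requests.
        self.request_streams: dict = {}
        self.finished_requests: asyncio.Queue = asyncio.Queue()

    def finish(self, request_id: str) -> None:
        # Called from wherever a request finishes or is aborted.
        self.finished_requests.put_nowait(request_id)

    async def engine_abort(self, request_ids: Iterable[str]) -> None:
        pass  # placeholder for the real abort call

    async def step(self) -> None:
        # Drain everything queued so far into a local set ...
        finished = set()
        while not self.finished_requests.empty():
            finished.add(self.finished_requests.get_nowait())
        # ... abort them in one call, then drop their streams.
        await self.engine_abort(finished)
        for request_id in finished:
            self.request_streams.pop(request_id, None)


async def main() -> None:
    tracker = FinishedTracker()
    tracker.request_streams = {"req-0": object(), "req-1": object()}
    tracker.finish("req-0")
    await tracker.step()
    print(list(tracker.request_streams))  # ['req-1']


asyncio.run(main())
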
@@ -226,6 +228,8 @@
                         f"sampling params: {sampling_params}, "
                         f"prompt token ids: {prompt_token_ids}.")
 
+        if request_id in self.request_streams:
+            raise KeyError(f"Request {request_id} already exists.")
         stream = AsyncStream(request_id)
         self.request_streams[request_id] = stream
 
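
The new guard makes a reused request id fail fast with a KeyError instead of silently replacing the existing stream. A tiny illustration of the same check in isolation (request_streams and add_request here are toy stand-ins):

request_streams: dict = {}


def add_request(request_id: str) -> None:
    # Mirrors the added guard: duplicate ids are rejected up front.
    if request_id in request_streams:
        raise KeyError(f"Request {request_id} already exists.")
    request_streams[request_id] = object()  # stand-in for AsyncStream


add_request("req-0")
try:
    add_request("req-0")
except KeyError as exc:
    print(exc)  # 'Request req-0 already exists.'
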
@@ -316,7 +320,7 @@
             logger.info(f"Aborted request {request_id}.")
 
         self.request_streams[request_id].finish()
-        self.finished_requests.add(request_id)
+        self.finished_requests.put_nowait(request_id)
 
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""