Use queue for finished requests (#957)
parent fbd80ad409
commit c9927c1a6a
@@ -156,8 +156,8 @@ class VllmRunner:
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
         outputs = self.generate(prompts, greedy_params)
-        return [(output_ids[0], output_str[0]) for output_ids, output_str in
-                outputs]
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]
 
     def generate_beam_search(
         self,
@@ -1,7 +1,7 @@
 import asyncio
 import time
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Type, Union
 
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -152,7 +152,7 @@ class AsyncLLMEngine:
 
         # Request id -> stream.
         self.request_streams: Dict[str, AsyncStream] = {}
-        self.finished_requests: Set[str] = set()
+        self.finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self.background_loop = None
         if start_engine_loop:
             self.start_background_loop()
@@ -194,12 +194,14 @@ class AsyncLLMEngine:
                 if self.log_requests:
                     logger.info(f"Finished request {request_id}.")
                 self.request_streams[request_id].finish()
-                self.finished_requests.add(request_id)
+                self.finished_requests.put_nowait(request_id)
 
-        await self._engine_abort(self.finished_requests)
-        for request_id in self.finished_requests:
+        finished_request = set()
+        while not self.finished_requests.empty():
+            finished_request.add(self.finished_requests.get_nowait())
+        await self._engine_abort(finished_request)
+        for request_id in finished_request:
             del self.request_streams[request_id]
-        self.finished_requests.clear()
 
     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
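The hunk above is the core of the change: the shared finished_requests set becomes an asyncio.Queue that producers fill with put_nowait and that the engine loop drains once per step (the file is presumably vllm/engine/async_llm_engine.py; the name is not shown in this view). Below is a minimal, self-contained sketch of that drain pattern, not code from the commit; the names demo, finished, streams, and drained are illustrative only. The remaining hunks continue the same engine changes.

import asyncio

async def demo() -> None:
    # Producer side: mark requests as finished without awaiting,
    # mirroring put_nowait on an unbounded queue in the diff.
    finished: asyncio.Queue[str] = asyncio.Queue()
    streams = {"req-1": object(), "req-2": object()}
    finished.put_nowait("req-1")
    finished.put_nowait("req-1")  # the same id may be enqueued twice
    finished.put_nowait("req-2")

    # Consumer side: drain everything queued so far into a set,
    # which also de-duplicates ids, then clean up each id once.
    drained = set()
    while not finished.empty():
        drained.add(finished.get_nowait())
    for request_id in drained:
        del streams[request_id]

    assert not streams

asyncio.run(demo())

Collecting the drained ids into a set before deleting streams means an id that was enqueued more than once (for example finished and then aborted in the same iteration) is still cleaned up exactly once.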
@@ -226,6 +228,8 @@ class AsyncLLMEngine:
                         f"sampling params: {sampling_params}, "
                         f"prompt token ids: {prompt_token_ids}.")
 
+        if request_id in self.request_streams:
+            raise KeyError(f"Request {request_id} already exists.")
         stream = AsyncStream(request_id)
         self.request_streams[request_id] = stream
 
@@ -316,7 +320,7 @@ class AsyncLLMEngine:
             logger.info(f"Aborted request {request_id}.")
 
         self.request_streams[request_id].finish()
-        self.finished_requests.add(request_id)
+        self.finished_requests.put_nowait(request_id)
 
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""
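Besides the queue, the diff adds a guard in add_request: a request id that is already tracked now raises instead of silently overwriting its stream, and abort only enqueues the id for the background loop to clean up. The toy class below illustrates that shape under stated assumptions; Tracker and its methods are invented for this example and are not vLLM's API.

import asyncio
from typing import Dict


class Tracker:
    """Toy stand-in for the engine's per-request bookkeeping."""

    def __init__(self) -> None:
        self.request_streams: Dict[str, object] = {}
        self.finished_requests: asyncio.Queue[str] = asyncio.Queue()

    def add_request(self, request_id: str) -> None:
        # Reject duplicate ids instead of overwriting the old stream,
        # which would orphan whichever consumer is awaiting it.
        if request_id in self.request_streams:
            raise KeyError(f"Request {request_id} already exists.")
        self.request_streams[request_id] = object()

    def abort_request(self, request_id: str) -> None:
        # Only mark for cleanup; a background loop drains the queue later.
        self.finished_requests.put_nowait(request_id)


tracker = Tracker()
tracker.add_request("req-1")
try:
    tracker.add_request("req-1")
except KeyError as exc:
    print(exc)  # prints: 'Request req-1 already exists.'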