Use queue for finished requests (#957)
parent fbd80ad409
commit c9927c1a6a
@@ -156,8 +156,8 @@ class VllmRunner:
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
         outputs = self.generate(prompts, greedy_params)
-        return [(output_ids[0], output_str[0]) for output_ids, output_str in
-                outputs]
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]

     def generate_beam_search(
         self,
@@ -1,7 +1,7 @@
 import asyncio
 import time
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Type, Union

 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -152,7 +152,7 @@ class AsyncLLMEngine:

         # Request id -> stream.
         self.request_streams: Dict[str, AsyncStream] = {}
-        self.finished_requests: Set[str] = set()
+        self.finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self.background_loop = None
         if start_engine_loop:
             self.start_background_loop()
@@ -194,12 +194,14 @@ class AsyncLLMEngine:
                 if self.log_requests:
                     logger.info(f"Finished request {request_id}.")
                 self.request_streams[request_id].finish()
-                self.finished_requests.add(request_id)
+                self.finished_requests.put_nowait(request_id)

-        await self._engine_abort(self.finished_requests)
-        for request_id in self.finished_requests:
+        finished_request = set()
+        while not self.finished_requests.empty():
+            finished_request.add(self.finished_requests.get_nowait())
+        await self._engine_abort(finished_request)
+        for request_id in finished_request:
             del self.request_streams[request_id]
-        self.finished_requests.clear()

     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
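A minimal, standalone sketch of the pattern this hunk introduces (illustrative only, not the engine code): callers mark a request as finished with put_nowait(), and the engine loop later drains the queue into a local set before aborting requests and deleting their streams, so producers can keep enqueueing ids while the drained snapshot is processed. Only asyncio.Queue from the standard library is assumed; the names below are made up for the example.

import asyncio


async def main() -> None:
    # Unbounded queue, so put_nowait() never raises QueueFull here.
    finished_requests = asyncio.Queue()  # holds finished request ids (str)

    # Producers (e.g. an abort handler) enqueue finished ids without awaiting.
    finished_requests.put_nowait("req-0")
    finished_requests.put_nowait("req-1")
    finished_requests.put_nowait("req-0")  # duplicates collapse in the set

    # Drain whatever is queued right now into a local snapshot; ids enqueued
    # later stay in the queue for the next drain instead of being dropped.
    finished = set()
    while not finished_requests.empty():
        finished.add(finished_requests.get_nowait())

    print(sorted(finished))  # -> ['req-0', 'req-1']


asyncio.run(main())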
@@ -226,6 +228,8 @@ class AsyncLLMEngine:
                         f"sampling params: {sampling_params}, "
                         f"prompt token ids: {prompt_token_ids}.")

+        if request_id in self.request_streams:
+            raise KeyError(f"Request {request_id} already exists.")
         stream = AsyncStream(request_id)
         self.request_streams[request_id] = stream

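The two lines added above make registering the same request id twice an explicit error instead of silently replacing the existing stream. A tiny hypothetical mimic of that behavior (request_streams and add_request here are stand-ins for illustration, not the real AsyncLLMEngine API):

request_streams = {}  # request id -> stream placeholder


def add_request(request_id: str) -> None:
    if request_id in request_streams:
        raise KeyError(f"Request {request_id} already exists.")
    request_streams[request_id] = object()  # stands in for AsyncStream


add_request("req-0")
try:
    add_request("req-0")
except KeyError as exc:
    print(exc)  # 'Request req-0 already exists.'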
@@ -316,7 +320,7 @@ class AsyncLLMEngine:
             logger.info(f"Aborted request {request_id}.")

         self.request_streams[request_id].finish()
-        self.finished_requests.add(request_id)
+        self.finished_requests.put_nowait(request_id)

     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""