[V1][Frontend] Coalesce bunched RequestOutputs (#12298)
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
parent c5cffcd0cd
commit 24b0205f58
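In short: the V1 AsyncLLM output path can queue up several RequestOutputs for a single request while the consumer coroutine is busy. This change adds a RequestOutput.add() method that merges a later output into an earlier one, makes AsyncLLM.generate() drain the per-request queue and coalesce whatever is waiting behind the first output (merging DELTA outputs, otherwise keeping only the newest), and parametrizes the V1 async tests over RequestOutputKind, with engine shutdown moved into an ExitStack callback. The hunks below cover, in order, the V1 AsyncLLM test module, the RequestOutput class, and AsyncLLM itself.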
@@ -1,4 +1,5 @@
 import asyncio
+from contextlib import ExitStack
 from typing import List, Tuple
 
 import pytest
@@ -6,6 +7,7 @@ import pytest
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.platforms import current_platform
+from vllm.sampling_params import RequestOutputKind
 from vllm.v1.engine.async_llm import AsyncLLM
 
 if not current_platform.is_cuda():
@@ -18,28 +20,39 @@ ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B",
 
 
 async def generate(engine: AsyncLLM, request_id: str,
+                   output_kind: RequestOutputKind,
                    max_tokens: int) -> Tuple[int, str]:
     count = 0
-    async for _ in engine.generate(request_id=request_id,
-                                   prompt="Hello my name is Robert and",
-                                   sampling_params=SamplingParams(
-                                       max_tokens=max_tokens, temperature=0)):
+    sampling_params = SamplingParams(max_tokens=max_tokens,
+                                     output_kind=output_kind,
+                                     temperature=0)
+    async for out in engine.generate(request_id=request_id,
+                                     prompt="Hello my name is Robert and",
+                                     sampling_params=sampling_params):
 
-        count += 1
+        num_tokens = len(out.outputs[0].token_ids)
+        if output_kind == RequestOutputKind.DELTA:
+            count += num_tokens
+        else:
+            count = num_tokens
+
         await asyncio.sleep(0.)
 
     return count, request_id
 
 
+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio
-async def test_load(monkeypatch):
+async def test_load(monkeypatch, output_kind: RequestOutputKind):
     # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
     # so that in the future when we switch, we don't have to change all the
     # tests.
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)
 
         NUM_REQUESTS = 10000
         NUM_EXPECTED_TOKENS = 10
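A note on what the parametrized generate() helper above now checks: with RequestOutputKind.DELTA each streamed RequestOutput carries only the newly generated tokens, so the test sums the per-output counts; with FINAL_ONLY the output carries the whole sequence so far, so the count is simply overwritten by the latest length. A minimal sketch of that counting rule, using plain token-id lists rather than real vLLM outputs (count_tokens and chunks are illustrative names, not part of the diff):

from typing import List


def count_tokens(chunks: List[List[int]], delta: bool) -> int:
    """Sum DELTA chunks; otherwise keep the size of the latest (full) chunk."""
    count = 0
    for token_ids in chunks:
        if delta:
            count += len(token_ids)   # each chunk holds only the new tokens
        else:
            count = len(token_ids)    # each chunk holds the whole sequence
    return count


# DELTA stream: 2 + 1 + 3 new tokens across three outputs -> 6 in total.
assert count_tokens([[1, 2], [3], [4, 5, 6]], delta=True) == 6
# FINAL_ONLY stream: a single final output carrying all 6 tokens -> 6 as well.
assert count_tokens([[1, 2, 3, 4, 5, 6]], delta=False) == 6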
@@ -51,26 +64,33 @@ async def test_load(monkeypatch):
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))
+                    generate(engine, request_id, output_kind,
+                             NUM_EXPECTED_TOKENS)))
 
         # Confirm that we got all the EXPECTED tokens from the requests.
-        for task in tasks:
+        done, pending = await asyncio.wait(tasks,
+                                           return_when=asyncio.FIRST_EXCEPTION)
+        for task in pending:
+            task.cancel()
+        for task in done:
             num_generated_tokens, request_id = await task
             assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                 f"{request_id} generated {num_generated_tokens} but "
                 f"expected {NUM_EXPECTED_TOKENS}")
 
         assert not engine.output_processor.has_unfinished_requests()
-        engine.shutdown()
 
 
+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio
-async def test_abort(monkeypatch):
+async def test_abort(monkeypatch, output_kind: RequestOutputKind):
 
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)
 
         NUM_REQUESTS = 100
         NUM_EXPECTED_TOKENS = 100
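test_load above also switches from awaiting the tasks one by one to asyncio.wait(..., return_when=asyncio.FIRST_EXCEPTION), which returns as soon as any task fails so the remaining ones can be cancelled rather than left running. A self-contained sketch of that standard-library pattern (the worker coroutine is made up for illustration):

import asyncio


async def worker(i: int) -> int:
    await asyncio.sleep(0.01 * i)
    if i == 3:
        raise RuntimeError("boom")    # simulate one failing request
    return i


async def main() -> None:
    tasks = [asyncio.create_task(worker(i)) for i in range(5)]
    # Return as soon as any task raises; slower tasks stay in `pending`.
    done, pending = await asyncio.wait(tasks,
                                       return_when=asyncio.FIRST_EXCEPTION)
    for task in pending:
        task.cancel()                 # don't leave stragglers behind
    for task in done:
        try:
            print("finished:", task.result())
        except RuntimeError as exc:
            print("failed:", exc)


asyncio.run(main())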
@@ -83,7 +103,8 @@ async def test_abort(monkeypatch):
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))
+                    generate(engine, request_id, output_kind,
+                             NUM_EXPECTED_TOKENS)))
 
         # API server cancels requests when they disconnect.
         for idx in REQUEST_IDS_TO_ABORT:
@@ -108,9 +129,7 @@ async def test_abort(monkeypatch):
         # Confirm we can do another generation.
         request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
         task = asyncio.create_task(
-            generate(engine, request_id, NUM_EXPECTED_TOKENS))
+            generate(engine, request_id, output_kind, NUM_EXPECTED_TOKENS))
         num_generated_tokens, request_id = await task
         assert num_generated_tokens == NUM_EXPECTED_TOKENS
         assert not engine.output_processor.has_unfinished_requests()
-
-        engine.shutdown()
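Both tests now register engine.shutdown() with a contextlib.ExitStack instead of calling it at the end of the with block, so the engine is torn down even if an assertion fails partway through. A minimal sketch of that cleanup pattern with a stand-in class (FakeEngine is not part of vLLM):

from contextlib import ExitStack


class FakeEngine:
    """Stand-in for AsyncLLM, used only to show the cleanup pattern."""

    def shutdown(self) -> None:
        print("engine shut down")


def run_test() -> None:
    with ExitStack() as after:
        engine = FakeEngine()
        # Callbacks run (in reverse registration order) when the block
        # exits, whether it exits normally or via an exception.
        after.callback(engine.shutdown)
        assert 1 + 1 == 2   # shutdown still runs even if this assert fails


run_test()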
@@ -1,6 +1,6 @@
 import time
 from dataclasses import dataclass
-from typing import Dict, Generic, List, Optional
+from typing import Dict, Generic, List, MutableSequence, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
 
@@ -162,6 +162,26 @@ class RequestOutput:
             finished=finished,
         )
 
+    def add(self, next_output: "RequestOutput") -> None:
+        """Merge subsequent RequestOutput into this one"""
+
+        self.prompt = next_output.prompt
+        self.prompt_token_ids = next_output.prompt_token_ids
+        self.prompt_logprobs = next_output.prompt_logprobs
+        self.finished |= next_output.finished
+
+        #TODO assuming n == 1 for now
+        completion = self.outputs[0]
+        next_completion = next_output.outputs[0]
+        completion.text += next_completion.text
+        if not isinstance(completion.token_ids, MutableSequence):
+            completion.token_ids = list(completion.token_ids)
+        completion.token_ids.extend(next_completion.token_ids)
+        if next_completion.logprobs:
+            assert completion.logprobs is not None
+            completion.logprobs.extend(next_completion.logprobs)
+        completion.cumulative_logprob = next_completion.cumulative_logprob
+
     @classmethod
     def from_seq_group(
             cls, seq_group: SequenceGroup, use_cache: bool,
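The new RequestOutput.add() merges a later output into an earlier one: prompt fields are refreshed from the newer output, finished flags are OR-ed, and for the single completion (n == 1 is assumed, per the TODO) text and token_ids are concatenated while cumulative_logprob is taken from the newer chunk. A toy illustration of the same merge rules on simplified stand-in classes (these dataclasses are not vLLM's real RequestOutput/CompletionOutput):

from dataclasses import dataclass
from typing import List


@dataclass
class ToyCompletion:
    text: str
    token_ids: List[int]
    cumulative_logprob: float


@dataclass
class ToyOutput:
    outputs: List[ToyCompletion]
    finished: bool = False

    def add(self, nxt: "ToyOutput") -> None:
        """Apply the diff's merge rules, restricted to one completion."""
        self.finished |= nxt.finished
        cur, new = self.outputs[0], nxt.outputs[0]
        cur.text += new.text                      # concatenate delta text
        cur.token_ids.extend(new.token_ids)       # concatenate delta tokens
        cur.cumulative_logprob = new.cumulative_logprob  # newest value wins


first = ToyOutput([ToyCompletion("Hello", [1, 2], -0.5)])
second = ToyOutput([ToyCompletion(" world", [3], -1.2)], finished=True)
first.add(second)
assert first.outputs[0].text == "Hello world"
assert first.outputs[0].token_ids == [1, 2, 3]
assert first.finished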
@@ -15,7 +15,7 @@ from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
@@ -214,6 +214,14 @@ class AsyncLLM(EngineClient):
                 # task switching under load which helps performance).
                 out = q.get_nowait() if not q.empty() else await q.get()
 
+                # Coalesce any additional queued outputs
+                while not q.empty():
+                    next_out = q.get_nowait()
+                    if sampling_params.output_kind == RequestOutputKind.DELTA:
+                        out.add(next_out)
+                    else:
+                        out = next_out
+
                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
                 finished = out.finished
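The hunk above is the heart of the change: after taking one RequestOutput from the per-request queue, the consumer drains whatever else has piled up, merging DELTA outputs with add() and otherwise keeping only the newest, so a slow consumer sees one combined output per iteration instead of a backlog. A self-contained sketch of that drain-and-coalesce pattern, using an asyncio.Queue of plain strings as stand-ins for RequestOutputs:

import asyncio


async def producer(q: "asyncio.Queue[str]") -> None:
    # Simulate several delta outputs arriving while the consumer is busy.
    for chunk in ["Hel", "lo ", "wor", "ld"]:
        q.put_nowait(chunk)
        await asyncio.sleep(0)


async def consumer(q: "asyncio.Queue[str]", delta: bool) -> str:
    await asyncio.sleep(0.01)                  # fall behind on purpose
    out = q.get_nowait() if not q.empty() else await q.get()
    # Coalesce anything that queued up behind the first output.
    while not q.empty():
        nxt = q.get_nowait()
        out = out + nxt if delta else nxt      # merge deltas, else keep newest
    return out


async def main() -> None:
    q: "asyncio.Queue[str]" = asyncio.Queue()
    _, merged = await asyncio.gather(producer(q), consumer(q, delta=True))
    print("coalesced into one output:", repr(merged))   # 'Hello world'


asyncio.run(main())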