[V1][Frontend] Coalesce bunched RequestOutputs (#12298)

Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Nick Hill 2025-01-23 17:17:41 -08:00 committed by GitHub
parent c5cffcd0cd
commit 24b0205f58
3 changed files with 65 additions and 18 deletions
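In short: the V1 frontend's per-request output queue can accumulate several RequestOutputs while the consumer of AsyncLLM.generate() is busy, and instead of yielding each one separately the frontend now drains the queue and yields a single coalesced output. The sketch below uses plain lists with illustrative token strings (not vLLM objects) to show why the two RequestOutputKind values touched in this diff are handled differently: DELTA outputs carry only new tokens and can be concatenated, while non-delta outputs each reflect the full output so far, so only the newest one needs to be kept.

# Hypothetical per-yield payloads for one request; values are illustrative only.
delta_yields = [["Hello"], [" my", " name"], [" is"]]        # DELTA: only new tokens each time
snapshot_yields = [["Hello"],
                   ["Hello", " my", " name"],
                   ["Hello", " my", " name", " is"]]         # non-delta: full output so far

# Coalescing bunched DELTA outputs means concatenating them into one delta...
merged_delta = delta_yields[1] + delta_yields[2]
assert merged_delta == [" my", " name", " is"]

# ...while for non-delta kinds the newest queued output already supersedes the
# older ones, so earlier queued outputs can simply be dropped.
latest = snapshot_yields[-1]
assert latest == ["Hello", " my", " name", " is"]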

View File

@@ -1,4 +1,5 @@
 import asyncio
+from contextlib import ExitStack
 from typing import List, Tuple
 
 import pytest
@@ -6,6 +7,7 @@ import pytest
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.platforms import current_platform
+from vllm.sampling_params import RequestOutputKind
 from vllm.v1.engine.async_llm import AsyncLLM
 
 if not current_platform.is_cuda():
@@ -18,28 +20,39 @@ ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B",
 async def generate(engine: AsyncLLM, request_id: str,
+                   output_kind: RequestOutputKind,
                    max_tokens: int) -> Tuple[int, str]:
     count = 0
-    async for _ in engine.generate(request_id=request_id,
-                                   prompt="Hello my name is Robert and",
-                                   sampling_params=SamplingParams(
-                                       max_tokens=max_tokens, temperature=0)):
+    sampling_params = SamplingParams(max_tokens=max_tokens,
+                                     output_kind=output_kind,
+                                     temperature=0)
+    async for out in engine.generate(request_id=request_id,
+                                     prompt="Hello my name is Robert and",
+                                     sampling_params=sampling_params):
 
-        count += 1
+        num_tokens = len(out.outputs[0].token_ids)
+        if output_kind == RequestOutputKind.DELTA:
+            count += num_tokens
+        else:
+            count = num_tokens
+
         await asyncio.sleep(0.)
 
     return count, request_id
 
 
+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio
-async def test_load(monkeypatch):
+async def test_load(monkeypatch, output_kind: RequestOutputKind):
     # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
     # so that in the future when we switch, we don't have to change all the
     # tests.
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)
 
         NUM_REQUESTS = 10000
         NUM_EXPECTED_TOKENS = 10
@@ -51,26 +64,33 @@ async def test_load(monkeypatch):
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))
+                    generate(engine, request_id, output_kind,
+                             NUM_EXPECTED_TOKENS)))
 
         # Confirm that we got all the EXPECTED tokens from the requests.
-        for task in tasks:
+        done, pending = await asyncio.wait(tasks,
+                                           return_when=asyncio.FIRST_EXCEPTION)
+        for task in pending:
+            task.cancel()
+        for task in done:
             num_generated_tokens, request_id = await task
             assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                 f"{request_id} generated {num_generated_tokens} but "
                 f"expected {NUM_EXPECTED_TOKENS}")
 
         assert not engine.output_processor.has_unfinished_requests()
-        engine.shutdown()
 
 
+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio
-async def test_abort(monkeypatch):
+async def test_abort(monkeypatch, output_kind: RequestOutputKind):
 
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)
 
         NUM_REQUESTS = 100
         NUM_EXPECTED_TOKENS = 100
@@ -83,7 +103,8 @@ async def test_abort(monkeypatch):
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))
+                    generate(engine, request_id, output_kind,
+                             NUM_EXPECTED_TOKENS)))
 
         # API server cancels requests when they disconnect.
         for idx in REQUEST_IDS_TO_ABORT:
@@ -108,9 +129,7 @@ async def test_abort(monkeypatch):
         # Confirm we can do another generation.
         request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
         task = asyncio.create_task(
-            generate(engine, request_id, NUM_EXPECTED_TOKENS))
+            generate(engine, request_id, output_kind, NUM_EXPECTED_TOKENS))
         num_generated_tokens, request_id = await task
         assert num_generated_tokens == NUM_EXPECTED_TOKENS
         assert not engine.output_processor.has_unfinished_requests()
-        engine.shutdown()
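Two details in the test diff above are worth calling out. First, generate() now counts len(out.outputs[0].token_ids) on every yielded output, accumulating under DELTA and taking the latest value otherwise, so the per-request total stays correct even when several outputs are coalesced into one yield. Second, engine.shutdown() moves from the end of the test body into an ExitStack callback, so the engine is torn down even if an assertion fails, and asyncio.wait(..., return_when=FIRST_EXCEPTION) stops waiting on the first raised exception instead of awaiting every one of the 10000 tasks in turn. A minimal sketch of just the ExitStack cleanup pattern, with a made-up FakeEngine standing in for AsyncLLM:

from contextlib import ExitStack


class FakeEngine:
    """Hypothetical stand-in for AsyncLLM, only here to show the cleanup pattern."""

    def shutdown(self) -> None:
        print("engine shut down")


with ExitStack() as after:
    engine = FakeEngine()
    # Callbacks registered on the stack run when the with-block exits,
    # including on an exception, so a failing assert no longer leaks the engine.
    after.callback(engine.shutdown)
    assert 1 + 1 == 2  # ...test body...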

View File

@@ -1,6 +1,6 @@
 import time
 from dataclasses import dataclass
-from typing import Dict, Generic, List, Optional
+from typing import Dict, Generic, List, MutableSequence, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
@@ -162,6 +162,26 @@ class RequestOutput:
             finished=finished,
         )
 
+    def add(self, next_output: "RequestOutput") -> None:
+        """Merge subsequent RequestOutput into this one"""
+
+        self.prompt = next_output.prompt
+        self.prompt_token_ids = next_output.prompt_token_ids
+        self.prompt_logprobs = next_output.prompt_logprobs
+        self.finished |= next_output.finished
+
+        #TODO assuming n == 1 for now
+        completion = self.outputs[0]
+        next_completion = next_output.outputs[0]
+        completion.text += next_completion.text
+        if not isinstance(completion.token_ids, MutableSequence):
+            completion.token_ids = list(completion.token_ids)
+        completion.token_ids.extend(next_completion.token_ids)
+        if next_completion.logprobs:
+            assert completion.logprobs is not None
+            completion.logprobs.extend(next_completion.logprobs)
+        completion.cumulative_logprob = next_completion.cumulative_logprob
+
     @classmethod
     def from_seq_group(
         cls, seq_group: SequenceGroup, use_cache: bool,
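The hunk above (which, from the RequestOutput class and from_seq_group context, appears to be vllm/outputs.py) adds RequestOutput.add(): it folds a later output into an earlier one by concatenating the completion text, token IDs and logprobs, carrying forward the newest cumulative_logprob, and OR-ing the finished flag, under the stated n == 1 assumption. The sketch below traces the same field-level merge with an invented ToyCompletion dataclass (not vLLM's CompletionOutput), including the MutableSequence guard that copies an immutable token_ids sequence into a list before extending it in place.

from collections.abc import MutableSequence
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class ToyCompletion:
    """Invented stand-in for CompletionOutput, just enough to trace the merge."""
    text: str
    token_ids: Sequence[int]
    cumulative_logprob: Optional[float] = None


def merge(cur: ToyCompletion, nxt: ToyCompletion) -> None:
    # Same shape as the single-completion merge inside RequestOutput.add().
    cur.text += nxt.text
    if not isinstance(cur.token_ids, MutableSequence):
        # token_ids may be an immutable sequence (e.g. a tuple); copy it to a
        # list once so this and later deltas can be extended in place.
        cur.token_ids = list(cur.token_ids)
    cur.token_ids.extend(nxt.token_ids)
    # The running log-probability total always comes from the newest delta.
    cur.cumulative_logprob = nxt.cumulative_logprob


first = ToyCompletion(text=" Smith", token_ids=(791,), cumulative_logprob=-0.7)
second = ToyCompletion(text=" and", token_ids=(323,), cumulative_logprob=-1.3)
merge(first, second)
print(first.text, list(first.token_ids))  # " Smith and" [791, 323]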

View File

@@ -15,7 +15,7 @@ from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
@@ -214,6 +214,14 @@ class AsyncLLM(EngineClient):
                 # task switching under load which helps performance).
                 out = q.get_nowait() if not q.empty() else await q.get()
 
+                # Coalesce any additional queued outputs
+                while not q.empty():
+                    next_out = q.get_nowait()
+                    if sampling_params.output_kind == RequestOutputKind.DELTA:
+                        out.add(next_out)
+                    else:
+                        out = next_out
+
                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
                 finished = out.finished
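The final hunk (apparently vllm/v1/engine/async_llm.py, given the AsyncLLM(EngineClient) context) is where the coalescing actually happens: after pulling one RequestOutput, the per-request consumer drains anything else already sitting in its asyncio queue, merging via RequestOutput.add() for DELTA streams and keeping only the newest output otherwise. Below is a self-contained sketch of that drain loop; Kind, Out and consume_one() are invented stand-ins for RequestOutputKind, RequestOutput and the generate() consumer, not the real vLLM API.

import asyncio
from dataclasses import dataclass, field
from enum import Enum
from typing import List


class Kind(Enum):
    # Stand-in for the two RequestOutputKind values exercised in this commit.
    DELTA = 1
    FINAL_ONLY = 2


@dataclass
class Out:
    """Invented mini RequestOutput: just token IDs and a finished flag."""
    token_ids: List[int] = field(default_factory=list)
    finished: bool = False

    def add(self, nxt: "Out") -> None:
        self.token_ids.extend(nxt.token_ids)
        self.finished |= nxt.finished


async def consume_one(q: asyncio.Queue, kind: Kind) -> Out:
    # Take the first output without awaiting if possible (avoids a task switch).
    out = q.get_nowait() if not q.empty() else await q.get()
    # Coalesce whatever else piled up while the consumer was busy.
    while not q.empty():
        next_out = q.get_nowait()
        if kind == Kind.DELTA:
            out.add(next_out)   # deltas are merged into one output
        else:
            out = next_out      # non-delta outputs supersede older ones
    return out


async def main() -> None:
    q: asyncio.Queue = asyncio.Queue()
    for chunk in ([1, 2], [3], [4, 5]):
        q.put_nowait(Out(token_ids=list(chunk)))
    merged = await consume_one(q, Kind.DELTA)
    print(merged.token_ids)  # [1, 2, 3, 4, 5]: one yield instead of three


asyncio.run(main())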