[V1][Frontend] Coalesce bunched RequestOutputs (#12298)

Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
parent c5cffcd0cd
commit 24b0205f58
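For context on the change below: `SamplingParams.output_kind` selects how streamed `RequestOutput`s are shaped, and the coalescing logic has to treat the modes differently. A rough, self-contained illustration of the three kinds (simplified stand-in code, not vLLM's implementation):

```python
# Rough illustration (not vLLM code) of the streaming modes referenced below:
#   CUMULATIVE - each output carries every token generated so far
#   DELTA      - each output carries only the tokens produced since the last one
#   FINAL_ONLY - a single output is emitted at the end with all tokens
from typing import Iterator, List


def stream_outputs(tokens: List[int], kind: str) -> Iterator[List[int]]:
    generated: List[int] = []
    for i, tok in enumerate(tokens):
        generated.append(tok)
        if kind == "DELTA":
            yield [tok]
        elif kind == "CUMULATIVE":
            yield list(generated)
        elif kind == "FINAL_ONLY" and i == len(tokens) - 1:
            yield list(generated)


assert list(stream_outputs([1, 2, 3], "DELTA")) == [[1], [2], [3]]
assert list(stream_outputs([1, 2, 3], "CUMULATIVE")) == [[1], [1, 2], [1, 2, 3]]
assert list(stream_outputs([1, 2, 3], "FINAL_ONLY")) == [[1, 2, 3]]
```

This is why the test below counts `count += num_tokens` for DELTA but `count = num_tokens` otherwise, and why the frontend merges bunched outputs for DELTA but only keeps the latest one for the other kinds.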
@@ -1,4 +1,5 @@
 import asyncio
+from contextlib import ExitStack
 from typing import List, Tuple
 
 import pytest
@@ -6,6 +7,7 @@ import pytest
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.platforms import current_platform
+from vllm.sampling_params import RequestOutputKind
 from vllm.v1.engine.async_llm import AsyncLLM
 
 if not current_platform.is_cuda():
@@ -18,28 +20,39 @@ ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B",
 
 
 async def generate(engine: AsyncLLM, request_id: str,
+                   output_kind: RequestOutputKind,
                    max_tokens: int) -> Tuple[int, str]:
     count = 0
-    async for _ in engine.generate(request_id=request_id,
-                                   prompt="Hello my name is Robert and",
-                                   sampling_params=SamplingParams(
-                                       max_tokens=max_tokens, temperature=0)):
+    sampling_params = SamplingParams(max_tokens=max_tokens,
+                                     output_kind=output_kind,
+                                     temperature=0)
+    async for out in engine.generate(request_id=request_id,
+                                     prompt="Hello my name is Robert and",
+                                     sampling_params=sampling_params):
 
-        count += 1
+        num_tokens = len(out.outputs[0].token_ids)
+        if output_kind == RequestOutputKind.DELTA:
+            count += num_tokens
+        else:
+            count = num_tokens
+
         await asyncio.sleep(0.)
 
     return count, request_id
 
 
+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio
-async def test_load(monkeypatch):
+async def test_load(monkeypatch, output_kind: RequestOutputKind):
     # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
     # so that in the future when we switch, we don't have to change all the
     # tests.
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)
 
         NUM_REQUESTS = 10000
         NUM_EXPECTED_TOKENS = 10
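The `ExitStack` introduced above registers `engine.shutdown` as an exit callback, so the engine is torn down even when an assertion fails mid-test. A minimal standalone sketch of that pattern, with a hypothetical `FakeEngine` standing in for the real engine:

```python
# Minimal sketch of the cleanup pattern: ExitStack.callback registers a function
# to run when the with-block exits, even if the body raises, so shutdown() is
# never skipped. FakeEngine is illustrative only.
from contextlib import ExitStack


class FakeEngine:

    def __init__(self) -> None:
        self.shut_down = False

    def shutdown(self) -> None:
        self.shut_down = True


engine = FakeEngine()
try:
    with ExitStack() as after:
        after.callback(engine.shutdown)
        raise RuntimeError("simulated test failure")
except RuntimeError:
    pass

assert engine.shut_down  # shutdown still ran despite the exception
```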
@@ -51,26 +64,33 @@ async def test_load(monkeypatch):
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))
+                    generate(engine, request_id, output_kind,
+                             NUM_EXPECTED_TOKENS)))
 
         # Confirm that we got all the EXPECTED tokens from the requests.
-        for task in tasks:
+        done, pending = await asyncio.wait(tasks,
+                                           return_when=asyncio.FIRST_EXCEPTION)
+        for task in pending:
+            task.cancel()
+        for task in done:
             num_generated_tokens, request_id = await task
             assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                 f"{request_id} generated {num_generated_tokens} but "
                 f"expected {NUM_EXPECTED_TOKENS}")
 
         assert not engine.output_processor.has_unfinished_requests()
-        engine.shutdown()
 
 
+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio
-async def test_abort(monkeypatch):
+async def test_abort(monkeypatch, output_kind: RequestOutputKind):
 
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)
 
         NUM_REQUESTS = 100
         NUM_EXPECTED_TOKENS = 100
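The load test now waits with `return_when=asyncio.FIRST_EXCEPTION`, so a single failing request surfaces immediately rather than after all 10000 tasks finish, and whatever is still pending gets cancelled. A small self-contained sketch of that pattern, with a hypothetical `worker` in place of `generate`:

```python
# Standalone sketch of the wait/cancel pattern: return at the first exception,
# cancel what is still pending, then inspect the completed tasks; awaiting a
# failed task re-raises its exception. The worker here is illustrative only.
import asyncio


async def worker(i: int) -> int:
    if i == 3:
        raise ValueError("boom")
    await asyncio.sleep(0.01 * (i + 1))
    return i


async def main() -> None:
    tasks = [asyncio.create_task(worker(i)) for i in range(10)]
    done, pending = await asyncio.wait(tasks,
                                       return_when=asyncio.FIRST_EXCEPTION)
    for task in pending:
        task.cancel()
    failed = [t for t in done if t.exception() is not None]
    print(f"{len(done)} done, {len(pending)} cancelled, {len(failed)} failed")


asyncio.run(main())
```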
@@ -83,7 +103,8 @@ async def test_abort(monkeypatch):
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))
+                    generate(engine, request_id, output_kind,
+                             NUM_EXPECTED_TOKENS)))
 
         # API server cancels requests when they disconnect.
         for idx in REQUEST_IDS_TO_ABORT:
@@ -108,9 +129,7 @@ async def test_abort(monkeypatch):
         # Confirm we can do another generation.
         request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
         task = asyncio.create_task(
-            generate(engine, request_id, NUM_EXPECTED_TOKENS))
+            generate(engine, request_id, output_kind, NUM_EXPECTED_TOKENS))
         num_generated_tokens, request_id = await task
         assert num_generated_tokens == NUM_EXPECTED_TOKENS
         assert not engine.output_processor.has_unfinished_requests()
-
-        engine.shutdown()
@@ -1,6 +1,6 @@
 import time
 from dataclasses import dataclass
-from typing import Dict, Generic, List, Optional
+from typing import Dict, Generic, List, MutableSequence, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
 
@@ -162,6 +162,26 @@ class RequestOutput:
             finished=finished,
         )
 
+    def add(self, next_output: "RequestOutput") -> None:
+        """Merge subsequent RequestOutput into this one"""
+
+        self.prompt = next_output.prompt
+        self.prompt_token_ids = next_output.prompt_token_ids
+        self.prompt_logprobs = next_output.prompt_logprobs
+        self.finished |= next_output.finished
+
+        #TODO assuming n == 1 for now
+        completion = self.outputs[0]
+        next_completion = next_output.outputs[0]
+        completion.text += next_completion.text
+        if not isinstance(completion.token_ids, MutableSequence):
+            completion.token_ids = list(completion.token_ids)
+        completion.token_ids.extend(next_completion.token_ids)
+        if next_completion.logprobs:
+            assert completion.logprobs is not None
+            completion.logprobs.extend(next_completion.logprobs)
+        completion.cumulative_logprob = next_completion.cumulative_logprob
+
     @classmethod
     def from_seq_group(
         cls, seq_group: SequenceGroup, use_cache: bool,
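The new `RequestOutput.add()` merges a later output into an earlier one: completion text and token ids are concatenated, logprobs are extended, the finished flag accumulates, and only the first completion is handled (n == 1 for now). A simplified illustration of those merge semantics, using stand-in dataclasses rather than vLLM's real `RequestOutput`/`CompletionOutput` constructors:

```python
# Simplified illustration of the merge semantics above; Completion and Output
# are stand-ins, not vLLM classes.
from dataclasses import dataclass
from typing import List


@dataclass
class Completion:
    text: str
    token_ids: List[int]
    cumulative_logprob: float


@dataclass
class Output:
    outputs: List[Completion]
    finished: bool

    def add(self, next_output: "Output") -> None:
        """Fold a later delta output into this one (n == 1 assumed)."""
        self.finished |= next_output.finished
        cur, nxt = self.outputs[0], next_output.outputs[0]
        cur.text += nxt.text
        cur.token_ids.extend(nxt.token_ids)
        cur.cumulative_logprob = nxt.cumulative_logprob


first = Output([Completion(" Hello", [101], -0.5)], finished=False)
second = Output([Completion(" world", [102], -1.2)], finished=True)
first.add(second)
assert first.outputs[0].text == " Hello world"
assert first.outputs[0].token_ids == [101, 102]
assert first.finished and first.outputs[0].cumulative_logprob == -1.2
```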
@@ -15,7 +15,7 @@ from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
@@ -214,6 +214,14 @@ class AsyncLLM(EngineClient):
                 # task switching under load which helps performance).
                 out = q.get_nowait() if not q.empty() else await q.get()
 
+                # Coalesce any additional queued outputs
+                while not q.empty():
+                    next_out = q.get_nowait()
+                    if sampling_params.output_kind == RequestOutputKind.DELTA:
+                        out.add(next_out)
+                    else:
+                        out = next_out
+
                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
                 finished = out.finished
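The loop added above drains whatever has already accumulated in the per-request queue without awaiting: queued DELTA outputs are merged via `RequestOutput.add()`, while for cumulative or final-only kinds only the most recent output is kept. A minimal standalone sketch of the same drain-and-merge pattern, with illustrative types rather than vLLM's:

```python
# Standalone sketch of the drain-and-merge pattern: take one queued item
# (without awaiting if possible), then drain anything else already queued,
# merging deltas or keeping only the latest output otherwise.
import asyncio
from typing import List


async def next_output(q: "asyncio.Queue[List[int]]", delta: bool) -> List[int]:
    out = q.get_nowait() if not q.empty() else await q.get()
    # Coalesce any additional queued outputs
    while not q.empty():
        nxt = q.get_nowait()
        out = out + nxt if delta else nxt
    return out


async def main() -> None:
    q: "asyncio.Queue[List[int]]" = asyncio.Queue()
    for item in ([1], [2], [3]):  # producer ran ahead of the consumer
        q.put_nowait(item)
    assert await next_output(q, delta=True) == [1, 2, 3]   # deltas merged
    for item in ([4], [5], [6]):
        q.put_nowait(item)
    assert await next_output(q, delta=False) == [6]        # latest kept


asyncio.run(main())
```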