[Benchmark] Add --async-engine option to benchmark_throughput.py (#7964)
commit d4db9f53c8 (parent 2188a60c7e)
@@ -6,13 +6,16 @@ import time
 from typing import List, Optional, Tuple
 
 import torch
+import uvloop
 from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)
 
-from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.entrypoints.openai.api_server import (
+    build_async_engine_client_from_engine_args)
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, merge_async_iterators
 
 
 def sample_requests(
@@ -135,6 +138,93 @@ def run_vllm(
     return end - start
 
 
+async def run_vllm_async(
+    requests: List[Tuple[str, int, int]],
+    model: str,
+    tokenizer: str,
+    quantization: Optional[str],
+    tensor_parallel_size: int,
+    seed: int,
+    n: int,
+    use_beam_search: bool,
+    trust_remote_code: bool,
+    dtype: str,
+    max_model_len: Optional[int],
+    enforce_eager: bool,
+    kv_cache_dtype: str,
+    quantization_param_path: Optional[str],
+    device: str,
+    enable_prefix_caching: bool,
+    enable_chunked_prefill: bool,
+    max_num_batched_tokens: int,
+    distributed_executor_backend: Optional[str],
+    gpu_memory_utilization: float = 0.9,
+    num_scheduler_steps: int = 1,
+    use_v2_block_manager: bool = False,
+    download_dir: Optional[str] = None,
+    load_format: str = EngineArgs.load_format,
+    disable_async_output_proc: bool = False,
+    disable_frontend_multiprocessing: bool = False,
+) -> float:
+    from vllm import SamplingParams
+    engine_args = AsyncEngineArgs(
+        model=model,
+        tokenizer=tokenizer,
+        quantization=quantization,
+        tensor_parallel_size=tensor_parallel_size,
+        seed=seed,
+        trust_remote_code=trust_remote_code,
+        dtype=dtype,
+        max_model_len=max_model_len,
+        gpu_memory_utilization=gpu_memory_utilization,
+        enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
+        quantization_param_path=quantization_param_path,
+        device=device,
+        enable_prefix_caching=enable_prefix_caching,
+        download_dir=download_dir,
+        enable_chunked_prefill=enable_chunked_prefill,
+        max_num_batched_tokens=max_num_batched_tokens,
+        distributed_executor_backend=distributed_executor_backend,
+        load_format=load_format,
+        num_scheduler_steps=num_scheduler_steps,
+        use_v2_block_manager=use_v2_block_manager,
+        disable_async_output_proc=disable_async_output_proc,
+        worker_use_ray=False,
+        engine_use_ray=False,
+        disable_log_requests=True,
+    )
+
+    async with build_async_engine_client_from_engine_args(
+            engine_args, disable_frontend_multiprocessing) as llm:
+
+        # Add the requests to the engine.
+        prompts: List[str] = []
+        sampling_params: List[SamplingParams] = []
+        for prompt, _, output_len in requests:
+            prompts.append(prompt)
+            sampling_params.append(
+                SamplingParams(
+                    n=n,
+                    temperature=0.0 if use_beam_search else 1.0,
+                    top_p=1.0,
+                    use_beam_search=use_beam_search,
+                    ignore_eos=True,
+                    max_tokens=output_len,
+                ))
+
+        generators = []
+        start = time.perf_counter()
+        for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
+            generator = llm.generate(prompt, sp, request_id=f"test{i}")
+            generators.append(generator)
+        all_gens = merge_async_iterators(*generators)
+        async for i, res in all_gens:
+            pass
+        end = time.perf_counter()
+        return end - start
+
+
 def run_hf(
     requests: List[Tuple[str, int, int]],
     model: str,
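A note on the `merge_async_iterators` usage above: the helper (imported from `vllm.utils` in this change) fans several per-request output streams into one stream of `(index, output)` pairs, so the benchmark can drain every request concurrently and time the whole run with a single `async for`. The sketch below reproduces that consumption pattern with plain `asyncio` only; it is an illustrative stand-in, not vLLM's implementation, and the `merge` helper here is hypothetical.

import asyncio
from typing import AsyncIterator, Tuple, TypeVar

T = TypeVar("T")

_DONE = object()  # sentinel marking an exhausted source iterator


async def merge(*iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
    """Interleave items from several async iterators as (index, item) pairs."""
    queue: "asyncio.Queue" = asyncio.Queue()

    async def drain(i: int, it: AsyncIterator[T]) -> None:
        async for item in it:
            await queue.put((i, item))
        await queue.put((i, _DONE))

    tasks = [asyncio.create_task(drain(i, it)) for i, it in enumerate(iterators)]
    remaining = len(tasks)
    while remaining:
        i, item = await queue.get()
        if item is _DONE:
            remaining -= 1
            continue
        yield i, item
    await asyncio.gather(*tasks)


async def main() -> None:
    async def gen(tag: str, n: int) -> AsyncIterator[str]:
        for k in range(n):
            await asyncio.sleep(0.01)
            yield f"{tag}-{k}"

    # Mirrors the benchmark's `async for i, res in all_gens: pass` drain loop.
    async for i, item in merge(gen("a", 3), gen("b", 2)):
        print(i, item)


asyncio.run(main())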
@@ -230,7 +320,7 @@ def main(args: argparse.Namespace):
                                    args.output_len)
 
     if args.backend == "vllm":
-        elapsed_time = run_vllm(
+        run_args = [
             requests, args.model, args.tokenizer, args.quantization,
             args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
             args.trust_remote_code, args.dtype, args.max_model_len,
@@ -240,7 +330,14 @@ def main(args: argparse.Namespace):
             args.max_num_batched_tokens, args.distributed_executor_backend,
             args.gpu_memory_utilization, args.num_scheduler_steps,
             args.use_v2_block_manager, args.download_dir, args.load_format,
-            args.disable_async_output_proc)
+            args.disable_async_output_proc
+        ]
+
+        if args.async_engine:
+            run_args.append(args.disable_frontend_multiprocessing)
+            elapsed_time = uvloop.run(run_vllm_async(*run_args))
+        else:
+            elapsed_time = run_vllm(*run_args)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
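On the dispatch just above: with `--async-engine`, the coroutine `run_vllm_async(*run_args)` is executed by `uvloop.run(...)` instead of calling `run_vllm` synchronously. A minimal sketch of that pattern (assumptions: the coroutine body is a stand-in for the real benchmark, and the installed uvloop release provides `uvloop.run`):

import asyncio
import time

import uvloop  # drop-in, faster event loop for asyncio


async def fake_benchmark() -> float:
    # Stand-in for run_vllm_async: await some work, report elapsed seconds.
    start = time.perf_counter()
    await asyncio.sleep(0.1)
    return time.perf_counter() - start


# Plays the role of `elapsed_time = uvloop.run(run_vllm_async(*run_args))`.
elapsed = uvloop.run(fake_benchmark())
print(f"elapsed: {elapsed:.3f}s")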
@@ -426,6 +523,14 @@ if __name__ == "__main__":
                         action='store_true',
                         default=False,
                         help="Disable async output processor for vLLM backend.")
+    parser.add_argument("--async-engine",
+                        action='store_true',
+                        default=False,
+                        help="Use vLLM async engine rather than LLM class.")
+    parser.add_argument("--disable-frontend-multiprocessing",
+                        action='store_true',
+                        default=False,
+                        help="Disable decoupled async engine frontend.")
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
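With the two flags in place, the async path can be exercised with something like `python benchmark_throughput.py --backend vllm --model <model> --output-len 128 --async-engine`, optionally adding `--disable-frontend-multiprocessing` to keep the engine in-process; the script path and the other flags are assumed from the benchmark's existing interface, and only the last two are new in this change. The hunks that follow appear to come from the OpenAI API server module (`vllm.entrypoints.openai.api_server`, matching the import added at the top of the benchmark).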
@@ -67,7 +67,7 @@ _running_tasks: Set[asyncio.Task] = set()
 
 
 def model_is_embedding(model_name: str, trust_remote_code: bool,
-                       quantization: str) -> bool:
+                       quantization: Optional[str]) -> bool:
     return ModelConfig(model=model_name,
                        tokenizer=model_name,
                        tokenizer_mode="auto",
@@ -96,13 +96,6 @@ async def lifespan(app: FastAPI):
 @asynccontextmanager
 async def build_async_engine_client(
         args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
-    """
-    Create AsyncEngineClient, either:
-        - in-process using the AsyncLLMEngine Directly
-        - multiprocess using AsyncLLMEngine RPC
-
-    Returns the Client or None if the creation failed.
-    """
 
     # Context manager to handle async_engine_client lifecycle
     # Ensures everything is shutdown and cleaned up on error/exit
@@ -112,14 +105,37 @@ async def build_async_engine_client(
     # Backend itself still global for the silly lil' health handler
     global async_engine_client
 
+    async with build_async_engine_client_from_engine_args(
+            engine_args, args.disable_frontend_multiprocessing) as engine:
+
+        async_engine_client = engine  # type: ignore[assignment]
+        yield engine
+
+
+@asynccontextmanager
+async def build_async_engine_client_from_engine_args(
+    engine_args: AsyncEngineArgs,
+    disable_frontend_multiprocessing: bool = False,
+) -> AsyncIterator[Optional[AsyncEngineClient]]:
+    """
+    Create AsyncEngineClient, either:
+        - in-process using the AsyncLLMEngine Directly
+        - multiprocess using AsyncLLMEngine RPC
+
+    Returns the Client or None if the creation failed.
+    """
+
     # If manually triggered or embedding model, use AsyncLLMEngine in process.
     # TODO: support embedding model via RPC.
-    if (model_is_embedding(args.model, args.trust_remote_code,
-                           args.quantization)
-            or args.disable_frontend_multiprocessing):
-        async_engine_client = AsyncLLMEngine.from_engine_args(
+    if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
+                           engine_args.quantization)
+            or disable_frontend_multiprocessing):
+        engine_client = AsyncLLMEngine.from_engine_args(
             engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-        yield async_engine_client
+        try:
+            yield engine_client
+        finally:
+            engine_client.shutdown_background_loop()
         return
 
     # Otherwise, use the multiprocessing AsyncLLMEngine.
@@ -148,7 +164,6 @@ async def build_async_engine_client(
     # NOTE: Actually, this is not true yet. We still need to support
     # embedding models via RPC (see TODO above)
     rpc_client = AsyncEngineRPCClient(rpc_path)
-    async_engine_client = rpc_client  # type: ignore
 
     # Start RPCServer in separate process (holds the AsyncLLMEngine).
     context = multiprocessing.get_context("spawn")
@@ -174,7 +189,7 @@ async def build_async_engine_client(
                     yield None
                     return
 
-        yield async_engine_client
+        yield rpc_client  # type: ignore[misc]
     finally:
         # Ensure rpc server process was terminated
         rpc_server_process.terminate()
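The refactor above extracts the reusable `build_async_engine_client_from_engine_args` context manager that the benchmark imports, while `build_async_engine_client` becomes a thin wrapper around it. Below is a minimal sketch of calling the new helper directly, assuming that passing `True` as the second argument selects the in-process `AsyncLLMEngine` path; the model name and prompt are placeholders.

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)


async def generate_once(model: str, prompt: str) -> str:
    engine_args = AsyncEngineArgs(model=model, disable_log_requests=True)
    # True == disable_frontend_multiprocessing: stay in-process instead of
    # spawning the RPC frontend.
    async with build_async_engine_client_from_engine_args(engine_args,
                                                          True) as client:
        final = None
        async for out in client.generate(prompt,
                                         SamplingParams(max_tokens=16),
                                         request_id="demo-0"):
            final = out  # streaming RequestOutputs; keep the last one
        return final.outputs[0].text if final is not None else ""


# Needs model weights and a device, so left commented out here:
# print(asyncio.run(generate_once("facebook/opt-125m", "Hello, my name is")))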
@@ -7,6 +7,7 @@ from uuid import uuid4
 import cloudpickle
 import zmq
 import zmq.asyncio
+from zmq import Frame  # type: ignore[attr-defined]
 from zmq.asyncio import Socket
 
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
@@ -214,6 +215,7 @@ class AsyncEngineRPCClient:
 
             # Await the data from the Server.
             frame = await socket.recv(copy=False)
+            assert isinstance(frame, Frame)
             data = pickle.loads(frame.buffer)
 
             if isinstance(data, Exception):
@@ -247,6 +249,7 @@
                                    f"{self._data_timeout} ms")
 
             frame = await socket.recv(copy=False)
+            assert isinstance(frame, Frame)
             return pickle.loads(frame.buffer)
 
         # Make a new socket connection.
@@ -395,6 +398,7 @@
             # Stream back the results from the RPC Server.
             while not finished:
                 message = await socket.recv(copy=False)
+                assert isinstance(message, Frame)
                 request_output = pickle.loads(message.buffer)
 
                 if isinstance(request_output, Exception):
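The last three hunks, apparently in the `AsyncEngineRPCClient`, add `isinstance(..., Frame)` asserts after each zero-copy receive: with `copy=False`, pyzmq returns a `zmq.Frame` rather than `bytes`, and its `.buffer` memoryview is what gets unpickled, so the asserts make that explicit for the type checker. A standalone sketch of that pattern (not vLLM code; uses an in-process PAIR socket):

import asyncio
import pickle

import zmq
import zmq.asyncio
from zmq import Frame


async def main() -> None:
    ctx = zmq.asyncio.Context()
    server = ctx.socket(zmq.PAIR)
    server.bind("inproc://demo")
    client = ctx.socket(zmq.PAIR)
    client.connect("inproc://demo")

    await server.send(pickle.dumps({"ok": True}))

    # copy=False returns a zmq.Frame; .buffer is a zero-copy memoryview.
    frame = await client.recv(copy=False)
    assert isinstance(frame, Frame)  # same guard as the hunks above
    data = pickle.loads(frame.buffer)
    print(data)

    client.close()
    server.close()
    ctx.term()


asyncio.run(main())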