[Misc] Refactor benchmark_throughput.py (#9779)

Signed-off-by: Linkun Chen <github+anyscale@lkchen.net>
Co-authored-by: Linkun Chen <lkchen@github.com>
Co-authored-by: Linkun Chen <github+anyscale@lkchen.net>
Authored by lkchen on 2024-11-04 14:32:16 -08:00; committed by GitHub
parent 04cef2c6ab
commit 9a5664d4a4

benchmark_throughput.py

@@ -4,7 +4,7 @@ import dataclasses
 import json
 import random
 import time
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 import torch
 import uvloop
@ -15,16 +15,35 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer,
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import ( from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args) build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators from vllm.utils import FlexibleArgumentParser, merge_async_iterators
@dataclasses.dataclass
class SampleRequest:
"""A class representing a single inference request for benchmarking.
Attributes:
prompt: The input text prompt for the model.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens.
"""
prompt: str
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[MultiModalDataDict] = None
def sample_requests( def sample_requests(
dataset_path: str, dataset_path: str,
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int], fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]: ) -> List[SampleRequest]:
if fixed_output_len is not None and fixed_output_len < 4: if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small") raise ValueError("output_len too small")
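As a quick illustration of what the new SampleRequest type buys over the old Tuple[str, int, int], here is a minimal, self-contained sketch; the request values are invented, and MultiModalDataDict is stubbed with a plain dict so the snippet runs without vLLM installed:

import dataclasses
from typing import Optional

MultiModalDataDict = dict  # stand-in for vllm.multimodal.MultiModalDataDict


@dataclasses.dataclass
class SampleRequest:
    prompt: str
    prompt_len: int
    expected_output_len: int
    multi_modal_data: Optional[MultiModalDataDict] = None


# Old style: positional tuple, easy to misorder or misread.
prompt, prompt_len, output_len = ("What is the capital of France?", 8, 32)

# New style: named fields, with an optional multi-modal payload.
request = SampleRequest(prompt=prompt,
                        prompt_len=prompt_len,
                        expected_output_len=output_len)
assert request.expected_output_len == output_len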
@@ -41,7 +60,7 @@ def sample_requests(
     random.shuffle(dataset)
 
     # Filter out sequences that are too long or too short
-    filtered_dataset: List[Tuple[str, int, int]] = []
+    filtered_dataset: List[SampleRequest] = []
     for i in range(len(dataset)):
         if len(filtered_dataset) == num_requests:
             break
@@ -60,13 +79,16 @@ def sample_requests(
         if prompt_len > 1024 or prompt_len + output_len > 2048:
             # Prune too long sequences.
             continue
-        filtered_dataset.append((prompt, prompt_len, output_len))
+        filtered_dataset.append(
+            SampleRequest(prompt=prompt,
+                          prompt_len=prompt_len,
+                          expected_output_len=output_len))
 
     return filtered_dataset
 
 
 def run_vllm(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     n: int,
     engine_args: EngineArgs,
 ) -> float:
@ -74,17 +96,17 @@ def run_vllm(
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**dataclasses.asdict(engine_args))
# Add the requests to the engine. # Add the requests to the engine.
prompts: List[str] = [] prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = [] sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests: for request in requests:
prompts.append(prompt) prompts.append(TextPrompt(prompt=request.prompt))
sampling_params.append( sampling_params.append(
SamplingParams( SamplingParams(
n=n, n=n,
temperature=1.0, temperature=1.0,
top_p=1.0, top_p=1.0,
ignore_eos=True, ignore_eos=True,
max_tokens=output_len, max_tokens=request.expected_output_len,
)) ))
use_beam_search = False use_beam_search = False
@@ -94,11 +116,11 @@ def run_vllm(
         llm.generate(prompts, sampling_params, use_tqdm=True)
         end = time.perf_counter()
     else:
-        prompts = [prompt for prompt, _, _ in requests]
+        prompts = [request.prompt for request in requests]
         # output_len should be the same for all requests.
         output_len = requests[0][2]
-        for prompt, input_len, _output_len in requests:
-            assert _output_len == output_len
+        for request in requests:
+            assert request.expected_output_len == output_len
         start = time.perf_counter()
         llm.beam_search(
             prompts,
@@ -112,7 +134,7 @@ def run_vllm(
 
 
 async def run_vllm_async(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     n: int,
     engine_args: AsyncEngineArgs,
     disable_frontend_multiprocessing: bool = False,
@@ -123,17 +145,17 @@ async def run_vllm_async(
             engine_args, disable_frontend_multiprocessing) as llm:
 
         # Add the requests to the engine.
-        prompts: List[str] = []
+        prompts: List[TextPrompt] = []
         sampling_params: List[SamplingParams] = []
-        for prompt, _, output_len in requests:
-            prompts.append(prompt)
+        for request in requests:
+            prompts.append(TextPrompt(prompt=request.prompt))
             sampling_params.append(
                 SamplingParams(
                     n=n,
                     temperature=1.0,
                     top_p=1.0,
                     ignore_eos=True,
-                    max_tokens=output_len,
+                    max_tokens=request.expected_output_len,
                 ))
 
         generators = []
@@ -149,7 +171,7 @@ async def run_vllm_async(
 
 
 def run_hf(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     model: str,
     tokenizer: PreTrainedTokenizerBase,
     n: int,
@@ -207,14 +229,14 @@ def run_hf(
 
 
 def run_mii(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     model: str,
     tensor_parallel_size: int,
     output_len: int,
 ) -> float:
     from mii import client, serve
     llm = serve(model, tensor_parallel=tensor_parallel_size)
-    prompts = [prompt for prompt, _, _ in requests]
+    prompts = [request.prompt for request in requests]
 
     start = time.perf_counter()
     llm.generate(prompts, max_new_tokens=output_len)
@@ -243,8 +265,12 @@ def main(args: argparse.Namespace):
         else:
             raise ValueError(
                 f"Failed to synthesize a prompt with {args.input_len} tokens.")
-        requests = [(prompt, args.input_len, args.output_len)
-                    for _ in range(args.num_prompts)]
+        requests = [
+            SampleRequest(prompt=prompt,
+                          prompt_len=args.input_len,
+                          expected_output_len=args.output_len)
+            for _ in range(args.num_prompts)
+        ]
     else:
         requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                    args.output_len)
@@ -270,9 +296,10 @@ def main(args: argparse.Namespace):
                                args.output_len)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")
-    total_num_tokens = sum(prompt_len + output_len
-                           for _, prompt_len, output_len in requests)
-    total_output_tokens = sum(output_len for _, _, output_len in requests)
+    total_num_tokens = sum(request.prompt_len + request.expected_output_len
+                           for request in requests)
+    total_output_tokens = sum(request.expected_output_len
+                              for request in requests)
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
           f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
@@ -299,7 +326,9 @@ if __name__ == "__main__":
     parser.add_argument("--dataset",
                         type=str,
                         default=None,
-                        help="Path to the dataset.")
+                        help="Path to the dataset. The dataset is expected to "
+                        "be a json in form of List[Dict[..., conversations: "
+                        "List[Dict[..., value: <prompt_or_response>]]]]")
     parser.add_argument("--input-len",
                         type=int,
                         default=None,
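For concreteness, here is a tiny Python script that writes a dataset file matching the expanded --dataset help string above. It follows the ShareGPT-style layout the benchmark parses (first conversation turn as the prompt, second as the reference completion); the "id"/"from" keys and the text values are invented for illustration:

import json

dataset = [{
    "id": "example-0",
    "conversations": [
        {"from": "human", "value": "Explain what a dataclass is in Python."},
        {"from": "gpt", "value": "A dataclass generates __init__ and other "
                                 "methods from declared fields."},
    ],
}]

# conversations[0]["value"] becomes the prompt; conversations[1]["value"]
# is used to estimate the output length when --output-len is not fixed.
with open("tiny_dataset.json", "w") as f:
    json.dump(dataset, f, indent=2)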