[Misc] Refactor benchmark_throughput.py (#9779)
Signed-off-by: Linkun Chen <github+anyscale@lkchen.net>
Co-authored-by: Linkun Chen <lkchen@github.com>
Co-authored-by: Linkun Chen <github+anyscale@lkchen.net>
parent 04cef2c6ab
commit 9a5664d4a4
@@ -4,7 +4,7 @@ import dataclasses
 import json
 import random
 import time
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 import torch
 import uvloop
@@ -15,16 +15,35 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer,
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
+from vllm.inputs import TextPrompt
+from vllm.multimodal import MultiModalDataDict
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import FlexibleArgumentParser, merge_async_iterators
 
 
+@dataclasses.dataclass
+class SampleRequest:
+    """A class representing a single inference request for benchmarking.
+
+    Attributes:
+        prompt: The input text prompt for the model.
+        multi_modal_data: Optional dictionary containing multi-modal data (e.g.
+            images).
+        prompt_len: The length of the prompt in tokens.
+        expected_output_len: The expected length of the output in tokens.
+    """
+    prompt: str
+    prompt_len: int
+    expected_output_len: int
+    multi_modal_data: Optional[MultiModalDataDict] = None
+
+
 def sample_requests(
     dataset_path: str,
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int],
-) -> List[Tuple[str, int, int]]:
+) -> List[SampleRequest]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
 
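The tuple (prompt, prompt_len, output_len) is replaced by a named dataclass. A minimal, self-contained sketch of how the new class behaves, with the definition copied from the hunk above and illustrative values (MultiModalDataDict is stubbed as a plain dict so the snippet runs without vLLM installed):

    import dataclasses
    from typing import Optional

    # Stub for vllm.multimodal.MultiModalDataDict; the benchmark script
    # imports the real type.
    MultiModalDataDict = dict

    @dataclasses.dataclass
    class SampleRequest:
        prompt: str
        prompt_len: int
        expected_output_len: int
        multi_modal_data: Optional[MultiModalDataDict] = None

    # Named fields replace positional tuple indexing: request.prompt_len
    # instead of request[1], request.expected_output_len instead of request[2].
    request = SampleRequest(prompt="Hello, world!",
                            prompt_len=4,
                            expected_output_len=16)
    print(request.prompt_len, request.expected_output_len)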
@@ -41,7 +60,7 @@ def sample_requests(
     random.shuffle(dataset)
 
     # Filter out sequences that are too long or too short
-    filtered_dataset: List[Tuple[str, int, int]] = []
+    filtered_dataset: List[SampleRequest] = []
     for i in range(len(dataset)):
         if len(filtered_dataset) == num_requests:
             break
@@ -60,13 +79,16 @@ def sample_requests(
         if prompt_len > 1024 or prompt_len + output_len > 2048:
             # Prune too long sequences.
             continue
-        filtered_dataset.append((prompt, prompt_len, output_len))
+        filtered_dataset.append(
+            SampleRequest(prompt=prompt,
+                          prompt_len=prompt_len,
+                          expected_output_len=output_len))
 
     return filtered_dataset
 
 
 def run_vllm(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     n: int,
     engine_args: EngineArgs,
 ) -> float:
@@ -74,17 +96,17 @@ def run_vllm(
     llm = LLM(**dataclasses.asdict(engine_args))
 
     # Add the requests to the engine.
-    prompts: List[str] = []
+    prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
-    for prompt, _, output_len in requests:
-        prompts.append(prompt)
+    for request in requests:
+        prompts.append(TextPrompt(prompt=request.prompt))
         sampling_params.append(
             SamplingParams(
                 n=n,
                 temperature=1.0,
                 top_p=1.0,
                 ignore_eos=True,
-                max_tokens=output_len,
+                max_tokens=request.expected_output_len,
             ))
 
     use_beam_search = False
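For throughput benchmarking, each request must generate exactly its expected number of tokens, which is why every SamplingParams pins max_tokens and ignores EOS. A standalone sketch of that configuration (the fixed length here is illustrative; in the script it comes from request.expected_output_len):

    from vllm import SamplingParams

    expected_output_len = 128  # illustrative
    params = SamplingParams(
        n=1,                # one output sequence per prompt
        temperature=1.0,    # sample from the raw distribution
        top_p=1.0,          # no nucleus truncation
        ignore_eos=True,    # do not stop early at EOS
        max_tokens=expected_output_len,
    )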
@@ -94,11 +116,11 @@ def run_vllm(
         llm.generate(prompts, sampling_params, use_tqdm=True)
         end = time.perf_counter()
     else:
-        prompts = [prompt for prompt, _, _ in requests]
+        prompts = [request.prompt for request in requests]
         # output_len should be the same for all requests.
         output_len = requests[0][2]
-        for prompt, input_len, _output_len in requests:
-            assert _output_len == output_len
+        for request in requests:
+            assert request.expected_output_len == output_len
         start = time.perf_counter()
         llm.beam_search(
             prompts,
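Note that the beam-search branch keeps the context line output_len = requests[0][2]. A dataclass is not subscriptable, so with SampleRequest objects this indexing would raise TypeError at run time; the field-based equivalent would be:

    # requests[0] is now a SampleRequest, not a tuple, so [2] is invalid.
    output_len = requests[0].expected_output_len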
@@ -112,7 +134,7 @@ def run_vllm(
 
 
 async def run_vllm_async(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     n: int,
     engine_args: AsyncEngineArgs,
     disable_frontend_multiprocessing: bool = False,
@@ -123,17 +145,17 @@ async def run_vllm_async(
             engine_args, disable_frontend_multiprocessing) as llm:
 
         # Add the requests to the engine.
-        prompts: List[str] = []
+        prompts: List[TextPrompt] = []
         sampling_params: List[SamplingParams] = []
-        for prompt, _, output_len in requests:
-            prompts.append(prompt)
+        for request in requests:
+            prompts.append(TextPrompt(prompt=request.prompt))
             sampling_params.append(
                 SamplingParams(
                     n=n,
                     temperature=1.0,
                     top_p=1.0,
                     ignore_eos=True,
-                    max_tokens=output_len,
+                    max_tokens=request.expected_output_len,
                 ))
 
         generators = []
@@ -149,7 +171,7 @@ async def run_vllm_async(
 
 
 def run_hf(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     model: str,
     tokenizer: PreTrainedTokenizerBase,
     n: int,
@@ -207,14 +229,14 @@ def run_hf(
 
 
 def run_mii(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     model: str,
     tensor_parallel_size: int,
     output_len: int,
 ) -> float:
     from mii import client, serve
     llm = serve(model, tensor_parallel=tensor_parallel_size)
-    prompts = [prompt for prompt, _, _ in requests]
+    prompts = [request.prompt for request in requests]
 
     start = time.perf_counter()
     llm.generate(prompts, max_new_tokens=output_len)
@@ -243,8 +265,12 @@ def main(args: argparse.Namespace):
         else:
             raise ValueError(
                 f"Failed to synthesize a prompt with {args.input_len} tokens.")
-        requests = [(prompt, args.input_len, args.output_len)
-                    for _ in range(args.num_prompts)]
+        requests = [
+            SampleRequest(prompt=prompt,
+                          prompt_len=args.input_len,
+                          expected_output_len=args.output_len)
+            for _ in range(args.num_prompts)
+        ]
     else:
         requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                    args.output_len)
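The synthetic-prompt path builds num_prompts identical requests. An equivalent standalone construction, reusing SampleRequest as sketched earlier (the argument values are illustrative):

    num_prompts, input_len, output_len = 4, 32, 64  # illustrative args
    prompt = "hello " * input_len  # stands in for the synthesized prompt
    requests = [
        SampleRequest(prompt=prompt,
                      prompt_len=input_len,
                      expected_output_len=output_len)
        for _ in range(num_prompts)
    ]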
@@ -270,9 +296,10 @@ def main(args: argparse.Namespace):
                               args.output_len)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")
-    total_num_tokens = sum(prompt_len + output_len
-                           for _, prompt_len, output_len in requests)
-    total_output_tokens = sum(output_len for _, _, output_len in requests)
+    total_num_tokens = sum(request.prompt_len + request.expected_output_len
+                           for request in requests)
+    total_output_tokens = sum(request.expected_output_len
+                              for request in requests)
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
           f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
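With named fields, the summary metrics read directly off the dataclass. A self-contained sketch of the throughput arithmetic (a trimmed SampleRequest and illustrative numbers):

    import dataclasses

    @dataclasses.dataclass
    class SampleRequest:  # trimmed to the fields used here
        prompt: str
        prompt_len: int
        expected_output_len: int

    requests = [SampleRequest("a" * 8, 8, 32) for _ in range(100)]
    elapsed_time = 2.5  # seconds, illustrative

    total_num_tokens = sum(r.prompt_len + r.expected_output_len
                           for r in requests)
    total_output_tokens = sum(r.expected_output_len for r in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")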
@@ -299,7 +326,9 @@ if __name__ == "__main__":
     parser.add_argument("--dataset",
                         type=str,
                         default=None,
-                        help="Path to the dataset.")
+                        help="Path to the dataset. The dataset is expected to "
+                        "be a json in form of List[Dict[..., conversations: "
+                        "List[Dict[..., value: <prompt_or_response>]]]]")
     parser.add_argument("--input-len",
                         type=int,
                         default=None,
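The new help string describes a ShareGPT-style layout. An illustrative file in that shape, generated with Python (the "from" field and surrounding keys vary by dataset; only "conversations" and "value" are assumed by the help text):

    import json

    dataset = [
        {
            "conversations": [
                {"from": "human", "value": "What is the capital of France?"},
                {"from": "gpt", "value": "The capital of France is Paris."},
            ]
        },
    ]

    with open("sample_dataset.json", "w") as f:
        json.dump(dataset, f, indent=2)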