[Misc] Refactor benchmark_throughput.py (#9779)
Signed-off-by: Linkun Chen <github+anyscale@lkchen.net>
Co-authored-by: Linkun Chen <lkchen@github.com>
Co-authored-by: Linkun Chen <github+anyscale@lkchen.net>
parent 04cef2c6ab
commit 9a5664d4a4
@@ -4,7 +4,7 @@ import dataclasses
 import json
 import random
 import time
-from typing import List, Optional, Tuple
+from typing import List, Optional

 import torch
 import uvloop
@@ -15,16 +15,35 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer,
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
+from vllm.inputs import TextPrompt
+from vllm.multimodal import MultiModalDataDict
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import FlexibleArgumentParser, merge_async_iterators


+@dataclasses.dataclass
+class SampleRequest:
+    """A class representing a single inference request for benchmarking.
+
+    Attributes:
+        prompt: The input text prompt for the model.
+        multi_modal_data: Optional dictionary containing multi-modal data (e.g.
+            images).
+        prompt_len: The length of the prompt in tokens.
+        expected_output_len: The expected length of the output in tokens.
+    """
+    prompt: str
+    prompt_len: int
+    expected_output_len: int
+    multi_modal_data: Optional[MultiModalDataDict] = None
+
+
 def sample_requests(
     dataset_path: str,
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int],
-) -> List[Tuple[str, int, int]]:
+) -> List[SampleRequest]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")

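For reference, a standalone sketch of how the new SampleRequest dataclass replaces the old (prompt, prompt_len, output_len) tuples; the plain-dict stand-in for MultiModalDataDict is an assumption made only to keep the snippet self-contained:

import dataclasses
from typing import Any, Dict, Optional

# Stand-in for vllm.multimodal.MultiModalDataDict so the sketch runs alone.
MultiModalDataDict = Dict[str, Any]

@dataclasses.dataclass
class SampleRequest:
    prompt: str
    prompt_len: int
    expected_output_len: int
    multi_modal_data: Optional[MultiModalDataDict] = None

# Before this commit a request was an anonymous tuple: ("Hello", 1, 128).
# Named fields make call sites self-documenting:
request = SampleRequest(prompt="Hello", prompt_len=1, expected_output_len=128)
assert request.prompt == "Hello" and request.multi_modal_data is None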
@@ -41,7 +60,7 @@ def sample_requests(
     random.shuffle(dataset)

     # Filter out sequences that are too long or too short
-    filtered_dataset: List[Tuple[str, int, int]] = []
+    filtered_dataset: List[SampleRequest] = []
     for i in range(len(dataset)):
         if len(filtered_dataset) == num_requests:
             break
@@ -60,13 +79,16 @@ def sample_requests(
         if prompt_len > 1024 or prompt_len + output_len > 2048:
             # Prune too long sequences.
             continue
-        filtered_dataset.append((prompt, prompt_len, output_len))
+        filtered_dataset.append(
+            SampleRequest(prompt=prompt,
+                          prompt_len=prompt_len,
+                          expected_output_len=output_len))

     return filtered_dataset


 def run_vllm(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     n: int,
     engine_args: EngineArgs,
 ) -> float:
@@ -74,17 +96,17 @@ def run_vllm(
     llm = LLM(**dataclasses.asdict(engine_args))

     # Add the requests to the engine.
-    prompts: List[str] = []
+    prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
-    for prompt, _, output_len in requests:
-        prompts.append(prompt)
+    for request in requests:
+        prompts.append(TextPrompt(prompt=request.prompt))
         sampling_params.append(
             SamplingParams(
                 n=n,
                 temperature=1.0,
                 top_p=1.0,
                 ignore_eos=True,
-                max_tokens=output_len,
+                max_tokens=request.expected_output_len,
             ))

     use_beam_search = False
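TextPrompt comes from vllm.inputs and behaves like a typed dict wrapping the prompt string (optionally with multi-modal data). A runnable sketch of the call-site change, with stub types standing in for the vLLM ones (the TypedDict shape here is an assumption, not vLLM's exact definition):

import dataclasses
from typing import List, TypedDict

class TextPrompt(TypedDict):  # stub for vllm.inputs.TextPrompt
    prompt: str

@dataclasses.dataclass
class SampleRequest:  # stub matching the dataclass added above
    prompt: str
    prompt_len: int
    expected_output_len: int

requests = [SampleRequest("Hi", 1, 8), SampleRequest("Bye", 1, 16)]

# Before: for prompt, _, output_len in requests: prompts.append(prompt)
# After: named attribute access, no positional unpacking to keep in sync.
prompts: List[TextPrompt] = [TextPrompt(prompt=r.prompt) for r in requests]
max_tokens = [r.expected_output_len for r in requests]
assert prompts[0] == {"prompt": "Hi"} and max_tokens == [8, 16]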
@@ -94,11 +116,11 @@ def run_vllm(
         llm.generate(prompts, sampling_params, use_tqdm=True)
         end = time.perf_counter()
     else:
-        prompts = [prompt for prompt, _, _ in requests]
+        prompts = [request.prompt for request in requests]
         # output_len should be the same for all requests.
         output_len = requests[0][2]
-        for prompt, input_len, _output_len in requests:
-            assert _output_len == output_len
+        for request in requests:
+            assert request.expected_output_len == output_len
         start = time.perf_counter()
         llm.beam_search(
             prompts,
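The beam-search branch requires a single output length shared by all requests; a minimal sketch of that invariant, using attribute access on the new dataclass (stub type, illustrative values):

import dataclasses

@dataclasses.dataclass
class SampleRequest:
    prompt: str
    prompt_len: int
    expected_output_len: int

requests = [SampleRequest("a", 1, 64), SampleRequest("b", 2, 64)]
output_len = requests[0].expected_output_len
for request in requests:
    assert request.expected_output_len == output_len  # all must match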
@@ -112,7 +134,7 @@ def run_vllm(


 async def run_vllm_async(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     n: int,
     engine_args: AsyncEngineArgs,
     disable_frontend_multiprocessing: bool = False,
@@ -123,17 +145,17 @@ async def run_vllm_async(
             engine_args, disable_frontend_multiprocessing) as llm:

         # Add the requests to the engine.
-        prompts: List[str] = []
+        prompts: List[TextPrompt] = []
         sampling_params: List[SamplingParams] = []
-        for prompt, _, output_len in requests:
-            prompts.append(prompt)
+        for request in requests:
+            prompts.append(TextPrompt(prompt=request.prompt))
             sampling_params.append(
                 SamplingParams(
                     n=n,
                     temperature=1.0,
                     top_p=1.0,
                     ignore_eos=True,
-                    max_tokens=output_len,
+                    max_tokens=request.expected_output_len,
                 ))

         generators = []
@@ -149,7 +171,7 @@ async def run_vllm_async(


 def run_hf(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     model: str,
     tokenizer: PreTrainedTokenizerBase,
     n: int,
@@ -207,14 +229,14 @@ def run_hf(


 def run_mii(
-    requests: List[Tuple[str, int, int]],
+    requests: List[SampleRequest],
     model: str,
     tensor_parallel_size: int,
     output_len: int,
 ) -> float:
     from mii import client, serve
     llm = serve(model, tensor_parallel=tensor_parallel_size)
-    prompts = [prompt for prompt, _, _ in requests]
+    prompts = [request.prompt for request in requests]

     start = time.perf_counter()
     llm.generate(prompts, max_new_tokens=output_len)
@@ -243,8 +265,12 @@ def main(args: argparse.Namespace):
         else:
             raise ValueError(
                 f"Failed to synthesize a prompt with {args.input_len} tokens.")
-        requests = [(prompt, args.input_len, args.output_len)
-                    for _ in range(args.num_prompts)]
+        requests = [
+            SampleRequest(prompt=prompt,
+                          prompt_len=args.input_len,
+                          expected_output_len=args.output_len)
+            for _ in range(args.num_prompts)
+        ]
     else:
         requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                    args.output_len)
@@ -270,9 +296,10 @@ def main(args: argparse.Namespace):
                                args.output_len)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")
-    total_num_tokens = sum(prompt_len + output_len
-                           for _, prompt_len, output_len in requests)
-    total_output_tokens = sum(output_len for _, _, output_len in requests)
+    total_num_tokens = sum(request.prompt_len + request.expected_output_len
+                           for request in requests)
+    total_output_tokens = sum(request.expected_output_len
+                              for request in requests)
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
           f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
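A worked example of the throughput arithmetic above, with hypothetical token counts and timing:

# Two requests with hypothetical lengths; elapsed_time stands in for the
# time.perf_counter() delta measured around the benchmark run.
prompt_lens = [10, 20]      # request.prompt_len for each request
output_lens = [100, 100]    # request.expected_output_len for each request
elapsed_time = 2.0          # seconds

total_num_tokens = sum(p + o for p, o in zip(prompt_lens, output_lens))  # 230
total_output_tokens = sum(output_lens)                                   # 200
print(f"Throughput: {len(prompt_lens) / elapsed_time:.2f} requests/s, "  # 1.00
      f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "          # 115.00
      f"{total_output_tokens / elapsed_time:.2f} output tokens/s")       # 100.00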
@@ -299,7 +326,9 @@ if __name__ == "__main__":
     parser.add_argument("--dataset",
                         type=str,
                         default=None,
-                        help="Path to the dataset.")
+                        help="Path to the dataset. The dataset is expected to "
+                        "be a json in form of List[Dict[..., conversations: "
+                        "List[Dict[..., value: <prompt_or_response>]]]]")
     parser.add_argument("--input-len",
                         type=int,
                         default=None,
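For illustration, a record shaped like the JSON the expanded --dataset help text describes (ShareGPT-style); keys other than "conversations" and "value" are assumptions:

# One entry of the expected dataset file; extra keys are permitted
# (the help text's "Dict[..., ...]" leaves them unspecified).
example_dataset = [
    {
        "id": "example-0",
        "conversations": [
            {"from": "human", "value": "What is the capital of France?"},
            {"from": "gpt", "value": "The capital of France is Paris."},
        ],
    },
]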
|
Loading…
x
Reference in New Issue
Block a user