2023-06-14 19:55:38 -07:00
|
|
|
"""Benchmark offline inference throughput."""
|
2023-05-28 03:20:05 -07:00
|
|
|
import argparse
|
|
|
|
import json
|
|
|
|
import random
|
|
|
|
import time
|
2023-09-16 00:03:37 -07:00
|
|
|
from typing import List, Optional, Tuple
|
2023-05-28 03:20:05 -07:00
|
|
|
|
2023-06-14 19:55:38 -07:00
|
|
|
import torch
|
2024-03-25 23:59:47 +09:00
|
|
|
from tqdm import tqdm
|
2023-11-14 12:35:30 -08:00
|
|
|
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
|
|
|
PreTrainedTokenizerBase)
|
2023-06-14 19:55:38 -07:00
|
|
|
|
2023-05-28 03:20:05 -07:00
|
|
|
|
|
|
|
def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    """Load a conversation dataset and randomly sample benchmark requests.

    Each dataset entry must contain a ``"conversations"`` list whose items
    carry a ``"value"`` string. The first turn is used as the prompt, and the
    token count of the second turn is used as the target output length.

    Args:
        dataset_path: Path to a JSON file holding a list of entries.
        num_requests: Number of requests to sample (without replacement).
        tokenizer: Callable tokenizer whose result exposes ``input_ids``; used
            only to measure prompt/completion lengths.
        fixed_output_len: If not None, overrides every request's output
            length. Must be at least 4.

    Returns:
        A list of ``(prompt, prompt_len, output_len)`` tuples.

    Raises:
        ValueError: If ``fixed_output_len`` is below 4, or if fewer than
            ``num_requests`` entries survive the length filters.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset. Conversation text is frequently non-ASCII, so do not
    # depend on the platform's default encoding.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Tokenize the prompts and completions (one batched call each).
    prompts = [prompt for prompt, _ in dataset]
    completions = [completion for _, completion in dataset]
    all_prompt_ids = tokenizer(prompts).input_ids
    all_completion_ids = tokenizer(completions).input_ids

    # Filter out too short / too long sequences.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_ids, completion_ids in zip(prompts, all_prompt_ids,
                                                 all_completion_ids):
        prompt_len = len(prompt_ids)
        output_len = (len(completion_ids)
                      if fixed_output_len is None else fixed_output_len)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # Sample the requests. Fail with an actionable message instead of
    # random.sample's generic "sample larger than population" error.
    if num_requests > len(filtered_dataset):
        raise ValueError(
            f"Requested {num_requests} samples, but only "
            f"{len(filtered_dataset)} requests in {dataset_path} remain "
            f"after length filtering.")
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests
|
|
|
|
|
|
|
|
|
2023-06-17 03:07:40 -07:00
|
|
|
def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    device: str,
    enable_prefix_caching: bool,
    gpu_memory_utilization: float = 0.9,
    download_dir: Optional[str] = None,
) -> float:
    """Measure how long the vLLM engine takes to complete ``requests``.

    An ``LLM`` engine is built from the given configuration, and one
    generation request is queued per entry. Only the engine run is timed;
    engine construction and request queuing happen before the clock starts.

    Args:
        requests: ``(prompt, prompt_len, output_len)`` tuples. ``prompt_len``
            is unused here; ``output_len`` becomes ``max_tokens``.
        model, tokenizer, quantization, tensor_parallel_size, seed,
        trust_remote_code, dtype, max_model_len, enforce_eager,
        kv_cache_dtype, device, enable_prefix_caching,
        gpu_memory_utilization, download_dir: Passed through unchanged to
            ``vllm.LLM``.
        n: Number of sequences generated per prompt.
        use_beam_search: Enables beam search; also selects temperature 0.0
            (otherwise 1.0).

    Returns:
        Elapsed wall-clock seconds for the engine run.
    """
    from vllm import LLM, SamplingParams

    engine_config = dict(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
    )
    llm = LLM(**engine_config)

    # Queue every request before the timer starts.
    temperature = 0.0 if use_beam_search else 1.0
    for text, _unused_prompt_len, max_new_tokens in requests:
        params = SamplingParams(
            n=n,
            temperature=temperature,
            top_p=1.0,
            use_beam_search=use_beam_search,
            ignore_eos=True,
            max_tokens=max_new_tokens,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt=text,
            prompt_token_ids=None,
            sampling_params=params,
        )

    started = time.perf_counter()
    # FIXME(woosuk): Do not use internal method.
    llm._run_engine(use_tqdm=True)
    finished = time.perf_counter()
    return finished - started
|
|
|
|
|
|
|
|
|
|
|
|
def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    """Time generation of ``requests`` with a Hugging Face causal LM.

    The model is loaded in float16 and moved to CUDA before timing starts.
    Consecutive requests are packed greedily into batches. A batch is
    flushed when it reaches ``max_batch_size``, when the input is exhausted,
    or when adding the next request would push the padded
    ``max(prompt_len) + max(output_len)`` above 2048. Each batch is
    tokenized, generated with sampling, and decoded. Decoding is included in
    the timing.

    Args:
        requests: ``(prompt, prompt_len, output_len)`` tuples.
        model: Model name or path for ``AutoModelForCausalLM``.
        tokenizer: Tokenizer for the model. For llama models its
            ``pad_token`` is overwritten with ``eos_token``.
        n: Number of sequences returned per prompt.
        use_beam_search: Must be False; beam search is not supported here.
        max_batch_size: Maximum number of prompts per generation batch.
        trust_remote_code: Forwarded to ``from_pretrained``.

    Returns:
        Elapsed wall-clock seconds spent on the batching loop.
    """
    assert not use_beam_search
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    # Use the progress bar as a context manager so it is always closed.
    with tqdm(total=len(requests)) as pbar:
        start = time.perf_counter()
        batch: List[str] = []
        max_prompt_len = 0
        max_output_len = 0
        for i, (prompt, prompt_len, output_len) in enumerate(requests):
            # Add the prompt to the batch.
            batch.append(prompt)
            max_prompt_len = max(max_prompt_len, prompt_len)
            max_output_len = max(max_output_len, output_len)
            if len(batch) < max_batch_size and i != len(requests) - 1:
                # Check if we can add more requests to the batch: the padded
                # batch length must stay within 2048 tokens.
                _, next_prompt_len, next_output_len = requests[i + 1]
                if (max(max_prompt_len, next_prompt_len) +
                        max(max_output_len, next_output_len)) <= 2048:
                    # We can add more requests to the batch.
                    continue

            # Generate the sequences.
            input_ids = tokenizer(batch, return_tensors="pt",
                                  padding=True).input_ids
            llm_outputs = llm.generate(
                input_ids=input_ids.cuda(),
                do_sample=not use_beam_search,
                num_return_sequences=n,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=max_output_len,
            )
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
            pbar.update(len(batch))

            # Clear the batch.
            batch = []
            max_prompt_len = 0
            max_output_len = 0
        end = time.perf_counter()
    return end - start
|
|
|
|
|
|
|
|
|
2023-11-14 12:35:30 -08:00
|
|
|
def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Time a single batched ``mii`` pipeline call over every prompt.

    The pipeline is created before timing starts. All prompts are then passed
    in one call, and the same ``output_len`` is used as ``max_new_tokens`` for
    every request.

    Args:
        requests: ``(prompt, prompt_len, output_len)`` tuples; only the
            prompt is used.
        model: Model name or path handed to ``mii.pipeline``.
        tensor_parallel_size: Passed as ``tensor_parallel``.
        output_len: ``max_new_tokens`` applied to all prompts.

    Returns:
        Elapsed wall-clock seconds for the pipeline call.
    """
    from mii import pipeline

    llm = pipeline(model, tensor_parallel=tensor_parallel_size)
    texts = [text for text, _plen, _olen in requests]

    t0 = time.perf_counter()
    llm(texts, max_new_tokens=output_len)
    t1 = time.perf_counter()
    return t1 - t0
|
|
|
|
|
|
|
|
|
2023-06-14 19:55:38 -07:00
|
|
|
def main(args: argparse.Namespace):
    """Build the request list, run the selected backend, and print throughput.

    Requests come from ``sample_requests`` when ``args.dataset`` is set.
    Otherwise ``args.num_prompts`` identical synthetic requests are built,
    each using the prompt ``"hi" * (input_len - 1)``. Throughput is reported
    as requests/s and as (prompt_len + output_len) tokens/s.

    Args:
        args: Parsed command-line arguments.

    Raises:
        ValueError: If ``args.backend`` is not one of vllm/hf/mii.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is not None:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)
    else:
        # Synthesize a prompt with the given input length.
        synthetic_prompt = "hi" * (args.input_len - 1)
        requests = [(synthetic_prompt, args.input_len, args.output_len)
                    ] * args.num_prompts

    backend = args.backend
    if backend == "vllm":
        elapsed_time = run_vllm(
            requests=requests,
            model=args.model,
            tokenizer=args.tokenizer,
            quantization=args.quantization,
            tensor_parallel_size=args.tensor_parallel_size,
            seed=args.seed,
            n=args.n,
            use_beam_search=args.use_beam_search,
            trust_remote_code=args.trust_remote_code,
            dtype=args.dtype,
            max_model_len=args.max_model_len,
            enforce_eager=args.enforce_eager,
            kv_cache_dtype=args.kv_cache_dtype,
            device=args.device,
            enable_prefix_caching=args.enable_prefix_caching,
            gpu_memory_utilization=args.gpu_memory_utilization,
            download_dir=args.download_dir,
        )
    elif backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(
            requests=requests,
            model=args.model,
            tokenizer=tokenizer,
            n=args.n,
            use_beam_search=args.use_beam_search,
            max_batch_size=args.hf_max_batch_size,
            trust_remote_code=args.trust_remote_code,
        )
    elif backend == "mii":
        elapsed_time = run_mii(
            requests=requests,
            model=args.model,
            tensor_parallel_size=args.tensor_parallel_size,
            output_len=args.output_len,
        )
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    total_num_tokens = sum(p_len + o_len for _, p_len, o_len in requests)
    requests_per_s = len(requests) / elapsed_time
    tokens_per_s = total_num_tokens / elapsed_time
    print(f"Throughput: {requests_per_s:.2f} requests/s, "
          f"{tokens_per_s:.2f} tokens/s")
|
2023-05-28 03:20:05 -07:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse and validate the arguments, then
    # run the benchmark.
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', 'gptq', 'squeezellm', None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1. '
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="enforce eager execution")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8_e5m2"],
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda"],
        help='device type for vLLM execution, supporting CUDA only currently.')
    parser.add_argument(
        "--enable-prefix-caching",
        action='store_true',
        help="enable automatic prefix caching for vLLM backend.")
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model

    # Validate with explicit errors rather than `assert`, which is stripped
    # when Python runs with -O.
    if args.dataset is None:
        if args.input_len is None or args.output_len is None:
            raise ValueError("--input-len and --output-len are required when "
                             "--dataset is not given.")
    elif args.input_len is not None:
        raise ValueError("--input-len cannot be combined with --dataset.")

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.use_beam_search:
            raise ValueError("Beam search is not supported for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
        # The MII run uses one output length for every prompt. Without it,
        # max_new_tokens would be None, and the reported token count would
        # use dataset lengths that were never requested.
        if args.output_len is None:
            raise ValueError("--output-len is required for MII backend.")

    main(args)
|