296 lines
11 KiB
Python
296 lines
11 KiB
Python
![]() |
"""Benchmark offline prioritization."""
|
||
|
import argparse
|
||
|
import json
|
||
|
import random
|
||
|
import time
|
||
|
from typing import List, Optional, Tuple
|
||
|
|
||
|
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||
|
|
||
|
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||
|
|
||
|
|
||
|
def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: "PreTrainedTokenizerBase",
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int, int]]:
    """Sample prioritized requests from a conversation-style JSON dataset.

    The file is read as a JSON list of records, each with a
    ``"conversations"`` list whose items carry a ``"value"`` string. The
    first turn is used as the prompt and the second as the reference
    completion.

    Args:
        dataset_path: Path to the JSON dataset file.
        num_requests: Maximum number of requests to return.
        tokenizer: Callable tokenizer whose result exposes ``input_ids``.
        fixed_output_len: If given, used as every request's output length
            instead of the tokenized completion length; must be >= 4.

    Returns:
        A list of ``(prompt, prompt_len, output_len, priority)`` tuples, where
        ``priority`` is 0 or 1, chosen with equal probability.

    Raises:
        ValueError: If ``fixed_output_len`` is smaller than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)

    # Keep only conversations with at least two turns, and of those only the
    # first two turns: (prompt, reference completion).
    pairs = [(data["conversations"][0]["value"],
              data["conversations"][1]["value"]) for data in dataset
             if len(data["conversations"]) >= 2]

    # Shuffle the dataset.
    random.shuffle(pairs)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int, int]] = []
    for prompt, completion in pairs:
        if len(filtered_dataset) == num_requests:
            break

        prompt_len = len(tokenizer(prompt).input_ids)
        # The completion only needs tokenizing when its length is used.
        if fixed_output_len is None:
            output_len = len(tokenizer(completion).input_ids)
        else:
            output_len = fixed_output_len

        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue

        # Select an equiprobable random priority (0 or 1). The RNG is drawn
        # only for requests that survive the filters above.
        priority = 0 if random.random() < 0.5 else 1
        filtered_dataset.append((prompt, prompt_len, output_len, priority))

    return filtered_dataset
|
||
|
|
||
|
|
||
|
def run_vllm(
    requests: List[Tuple[str, int, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: Optional[int],
    gpu_memory_utilization: float = 0.9,
    download_dir: Optional[str] = None,
) -> float:
    """Generate completions for all requests with per-request priorities.

    Builds an offline ``vllm.LLM`` engine from the given configuration, then
    submits every prompt in a single ``llm.generate`` call together with its
    priority.

    Args:
        requests: ``(prompt, prompt_len, output_len, priority)`` tuples.
            ``output_len`` becomes the request's ``max_tokens``; ``prompt_len``
            is not used here.
        model, tokenizer, quantization, tensor_parallel_size, seed,
        trust_remote_code, dtype, max_model_len, enforce_eager,
        kv_cache_dtype, quantization_param_path, device,
        enable_prefix_caching, enable_chunked_prefill,
        max_num_batched_tokens, gpu_memory_utilization, download_dir:
            Forwarded unchanged to the ``LLM`` constructor.
        n: Number of sequences to generate per prompt.
        use_beam_search: Enables beam search; also switches the sampling
            temperature from 1.0 to 0.0.

    Returns:
        Wall-clock seconds spent in ``llm.generate``. Engine construction
        happens before the timer starts and is not included.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        disable_log_stats=False,
    )

    # Add the requests to the engine.
    prompts = []
    sampling_params = []
    priority = []
    for prompt, _, output_len, _priority in requests:
        prompts.append(prompt)
        priority.append(_priority)
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=0.0 if use_beam_search else 1.0,
                top_p=1.0,
                use_beam_search=use_beam_search,
                # Always generate exactly max_tokens so throughput is
                # measured over a known number of output tokens.
                ignore_eos=True,
                max_tokens=output_len,
            ))

    start = time.perf_counter()
    llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
    end = time.perf_counter()
    return end - start
|
||
|
|
||
|
|
||
|
def main(args: argparse.Namespace):
    """Build the request list, run the benchmark, and report throughput.

    Requests are sampled from ``args.dataset`` when given; otherwise they are
    synthesized from ``args.input_len``/``args.output_len``. Either way each
    request is a ``(prompt, prompt_len, output_len, priority)`` tuple. The
    throughput summary is printed, and also written as JSON to
    ``args.output_json`` when that is set.

    Raises:
        ValueError: If ``args.backend`` is not ``"vllm"``.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        # Synthesize a prompt with the given input length. Each request also
        # gets an equiprobable 0/1 priority (as sample_requests does), so
        # every entry has the 4-tuple shape that run_vllm unpacks.
        prompt = "hi" * (args.input_len - 1)
        requests = [(prompt, args.input_len, args.output_len,
                     0 if random.random() < 0.5 else 1)
                    for _ in range(args.num_prompts)]
    else:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)

    if args.backend == "vllm":
        elapsed_time = run_vllm(
            requests, args.model, args.tokenizer, args.quantization,
            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
            args.trust_remote_code, args.dtype, args.max_model_len,
            args.enforce_eager, args.kv_cache_dtype,
            args.quantization_param_path, args.device,
            args.enable_prefix_caching, args.enable_chunked_prefill,
            args.max_num_batched_tokens, args.gpu_memory_utilization,
            args.download_dir)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    # Token count uses the requested lengths (prompt + max output per request).
    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len, _ in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    # Only the vLLM backend is implemented by main(); advertising other
    # choices would just defer the failure until after argument parsing.
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=200,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1. '
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="enforce eager execution")
    parser.add_argument(
        '--kv-cache-dtype',
        type=str,
        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
        default="auto",
        help='Data type for kv cache storage. If "auto", will use model '
        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help='device type for vLLM execution, supporting CUDA and CPU.')
    parser.add_argument(
        "--enable-prefix-caching",
        action='store_true',
        help="enable automatic prefix caching for vLLM backend.")
    parser.add_argument("--enable-chunked-prefill",
                        action='store_true',
                        help="enable chunked prefill for vLLM backend.")
    parser.add_argument('--max-num-batched-tokens',
                        type=int,
                        default=None,
                        help='maximum number of batched tokens per '
                        'iteration')
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')

    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    # Validate with parser.error rather than assert: asserts vanish under
    # `python -O`, and parser.error gives the user a usage message.
    if args.dataset is None:
        if args.input_len is None or args.output_len is None:
            parser.error("--input-len and --output-len are required when "
                         "--dataset is not given")
    elif args.input_len is not None:
        parser.error("--input-len cannot be combined with --dataset")

    main(args)
|