Add DeepSpeed MII backend to benchmark script (#1649)

This commit is contained in:
Woosuk Kwon 2023-11-14 12:35:30 -08:00 committed by GitHub
parent 054072bee5
commit 660a7fcfa4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -6,18 +6,21 @@ import time
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from tqdm import tqdm from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
def sample_requests( def sample_requests(
dataset_path: str, dataset_path: str,
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]: ) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None:
if fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset. # Load the dataset.
with open(dataset_path) as f: with open(dataset_path) as f:
dataset = json.load(f) dataset = json.load(f)
@ -35,6 +38,8 @@ def sample_requests(
tokenized_dataset = [] tokenized_dataset = []
for i in range(len(dataset)): for i in range(len(dataset)):
output_len = len(completion_token_ids[i]) output_len = len(completion_token_ids[i])
if fixed_output_len is not None:
output_len = fixed_output_len
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
# Filter out too long sequences. # Filter out too long sequences.
@ -66,6 +71,7 @@ def run_vllm(
trust_remote_code: bool, trust_remote_code: bool,
dtype: str, dtype: str,
) -> float: ) -> float:
from vllm import LLM, SamplingParams
llm = LLM( llm = LLM(
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
@ -160,14 +166,37 @@ def run_hf(
return end - start return end - start
def run_mii(
requests: List[Tuple[str, int, int]],
model: str,
tensor_parallel_size: int,
output_len: int,
) -> float:
from mii import pipeline
llm = pipeline(model, tensor_parallel=tensor_parallel_size)
prompts = [prompt for prompt, _, _ in requests]
start = time.perf_counter()
llm(prompts, max_new_tokens=output_len)
end = time.perf_counter()
return end - start
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
print(args) print(args)
random.seed(args.seed) random.seed(args.seed)
# Sample the requests. # Sample the requests.
tokenizer = get_tokenizer(args.tokenizer, tokenizer = AutoTokenizer.from_pretrained(
trust_remote_code=args.trust_remote_code) args.tokenizer, trust_remote_code=args.trust_remote_code)
requests = sample_requests(args.dataset, args.num_prompts, tokenizer) if args.dataset is None:
# Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1)
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
args.output_len)
if args.backend == "vllm": if args.backend == "vllm":
elapsed_time = run_vllm(requests, args.model, args.tokenizer, elapsed_time = run_vllm(requests, args.model, args.tokenizer,
@ -179,6 +208,9 @@ def main(args: argparse.Namespace):
elapsed_time = run_hf(requests, args.model, tokenizer, args.n, elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.use_beam_search, args.hf_max_batch_size, args.use_beam_search, args.hf_max_batch_size,
args.trust_remote_code) args.trust_remote_code)
elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
args.output_len)
else: else:
raise ValueError(f"Unknown backend: {args.backend}") raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len total_num_tokens = sum(prompt_len + output_len
@ -191,12 +223,21 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.") parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend", parser.add_argument("--backend",
type=str, type=str,
choices=["vllm", "hf"], choices=["vllm", "hf", "mii"],
default="vllm") default="vllm")
parser.add_argument("--dataset", parser.add_argument("--dataset",
type=str, type=str,
required=True, default=None,
help="Path to the dataset.") help="Path to the dataset.")
parser.add_argument("--input-len",
type=int,
default=None,
help="Input prompt length for each request")
parser.add_argument("--output-len",
type=int,
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m") parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization', parser.add_argument('--quantization',
@ -231,6 +272,13 @@ if __name__ == "__main__":
'for FP32 and FP16 models, and BF16 precision ' 'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.') 'for BF16 models.')
args = parser.parse_args() args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
if args.dataset is None:
assert args.input_len is not None
assert args.output_len is not None
else:
assert args.input_len is None
if args.backend == "vllm": if args.backend == "vllm":
if args.hf_max_batch_size is not None: if args.hf_max_batch_size is not None:
@ -240,7 +288,18 @@ if __name__ == "__main__":
raise ValueError("HF max batch size is required for HF backend.") raise ValueError("HF max batch size is required for HF backend.")
if args.quantization is not None: if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.") raise ValueError("Quantization is only for vLLM backend.")
if args.tokenizer is None: elif args.backend == "mii":
args.tokenizer = args.model if args.dtype != "auto":
raise ValueError("dtype must be auto for MII backend.")
if args.n != 1:
raise ValueError("n must be 1 for MII backend.")
if args.use_beam_search:
raise ValueError("Beam search is not supported for MII backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.")
if args.tokenizer != args.model:
raise ValueError("Tokenizer must be the same as the model for MII "
"backend.")
main(args) main(args)