vllm/benchmarks/benchmark_text_completion.py

256 lines
8.7 KiB
Python
Raw Normal View History

import argparse
import logging
import os
import pickle
import time
from typing import List
from tqdm import tqdm
from transformers import AutoConfig
from benchmark.trace import generate_text_completion_requests
from cacheflow.master.server import (
add_server_arguments, process_server_arguments,
init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams
logger = logging.getLogger(__name__)
def main(args: argparse.Namespace) -> None:
    """Replay a generated request trace against a local server and save stats.

    The function builds a local server/frontend pair from ``args``, generates
    the request trace, runs a short warm-up, and then submits each request
    once its arrival offset has elapsed while stepping the server.  One record
    is collected per sequence of every finished sequence group.  When the
    trace is drained (or the timeout fires) the scheduler statistics and the
    collected records are written under ``args.output_dir``.

    Args:
        args: Parsed command-line arguments.  This function reads
            ``dataset``, ``request_rate``, ``duration``, ``seed``, the
            ``n*`` / ``n*_beam`` ratios, ``timeout`` and ``output_dir``;
            the whole namespace is also handed to the server initializer.
    """
    server, frontend = init_local_server_and_frontend_with_arguments(args)
    # Generate requests.
    # Each entry is unpacked below as (arrival offset, input tokens,
    # sampling params).
    requests = generate_text_completion_requests(
        args.dataset,
        args.request_rate,
        args.duration,
        args.seed,
        args.n1,
        args.n2,
        args.n3,
        args.n4,
        args.n6,
        args.n2_beam,
        args.n4_beam,
        args.n6_beam,
        args.n8_beam,
    )
    # Warm up.
    # Submit a few short dummy prompts (all-zero token ids) and run them to
    # completion before any measurement starts.
    logger.info('Warming up.')
    num_warmup_requests = 8
    warmup_input_len = 8
    warmup_output_len = 32
    warmup_sampling_params = SamplingParams(
        n=1,
        temperature=1.0,
        top_p=0.99,
        max_num_steps=warmup_output_len,
        use_beam_search=False,
        stop_token_ids=set(),
        num_logprobs=0,
        context_window_size=None,
    )
    for _ in range(num_warmup_requests):
        frontend._add_query([0] * warmup_input_len, warmup_sampling_params)
    server.add_sequence_groups(frontend.get_inputs())
    while True:
        server.step()
        if not server.has_unfinished_requests():
            break
    # Start benchmarking.
    logger.info('Start benchmarking.')
    # Initialize tqdm.
    pbar = tqdm(total=len(requests), desc='Finished requests')
    # One dict per finished sequence; pickled at the end.
    finished = []
    # Presumably clears statistics accumulated during warm-up -- confirm
    # against the scheduler implementation.
    server.scheduler.reset_stats()
    start_time = time.time()
    while True:
        now = time.time()
        # Optional wall-clock limit: stop replaying even if requests remain.
        # Whatever has finished so far is still saved below.
        if args.timeout is not None and now - start_time > args.timeout:
            logger.info('Timeout. Stop benchmarking.')
            break
        # Submit every request whose arrival offset has already elapsed.
        # Only the head of the list is inspected, so this assumes the trace
        # is ordered by arrival offset -- TODO confirm with the generator.
        while requests:
            if requests[0][0] <= now - start_time:
                request_time, input_tokens, sampling_params = requests.pop(0)
                # Record the scheduled arrival time rather than the actual
                # submission time.
                frontend._add_query(
                    input_tokens, sampling_params, arrival_time=start_time + request_time)
            else:
                break
        server.add_sequence_groups(frontend.get_inputs())
        updated_seq_groups = server.step()
        # Timestamp taken right after the step; used as the finish time of
        # every group that this step reports as finished.
        now = time.time()
        for seq_group in updated_seq_groups:
            if not seq_group.is_finished():
                continue
            arrival_time = seq_group.arrival_time
            finish_time = now
            for seq in seq_group.get_seqs():
                seq_len = seq.get_len()
                # Generated length is the total length minus the prompt.
                output_len = seq_len - seq.prompt_len
                finished.append({
                    'group_id': seq_group.group_id,
                    'seq_id': seq.seq_id,
                    'arrival_time': arrival_time,
                    'finish_time': finish_time,
                    'prompt_len': seq.prompt_len,
                    'output_len': output_len,
                })
            # One tick per finished group, matching total=len(requests).
            pbar.update(1)
        # Done once the trace is drained and the server has nothing pending.
        if not (requests or server.has_unfinished_requests()):
            break
    pbar.close()
    logger.info('Finish benchmarking. Saving stats.')
    server.scheduler.save_stats(args.output_dir)
    with open(os.path.join(args.output_dir, 'sequences.pkl'), 'wb') as f:
        pickle.dump(finished, f)
    logger.info('Done.')
def get_model_name(model: str) -> str:
    """Map a model identifier to the short name used in output paths.

    OPT models are recognised by substring match against the known OPT
    variants.  Anything else is loaded through ``AutoConfig`` and must be a
    Llama model, whose size is inferred from its hidden dimension.

    Args:
        model: Model name or path, as accepted by
            ``AutoConfig.from_pretrained``.

    Returns:
        A short name such as ``'opt-13b'`` or ``'llama-7b'``.

    Raises:
        ValueError: If the model is neither a known OPT variant nor a Llama
            model with a recognised hidden size.
    """
    OPT_MODELS = [
        'opt-125m',
        'opt-350m',
        'opt-1.3b',
        'opt-2.7b',
        'opt-6.7b',
        'opt-13b',
        'opt-30b',
        'opt-66b',
        'opt-175b',
    ]
    for opt_model in OPT_MODELS:
        if opt_model in model:
            return opt_model

    config = AutoConfig.from_pretrained(model)
    # Raise instead of assert: asserts are stripped under ``python -O`` and
    # an unsupported architecture is a user-input error, not an invariant.
    if config.model_type != 'llama':
        raise ValueError(
            f'Unknown model: {model} '
            f'(unsupported model type {config.model_type!r})')

    # Llama sizes are distinguished by hidden dimension.
    llama_names_by_hidden_size = {
        4096: 'llama-7b',
        5120: 'llama-13b',
        6656: 'llama-30b',
        8192: 'llama-65b',
    }
    model_name = llama_names_by_hidden_size.get(config.hidden_size)
    if model_name is None:
        raise ValueError(f'Unknown model: {model}')
    return model_name
def get_dataset_name(dataset: str) -> str:
    """Return the canonical dataset name found in a dataset path.

    The match is case-insensitive; ``'sharegpt'`` takes precedence over
    ``'alpaca'`` when both appear.

    Args:
        dataset: Dataset path or identifier.

    Returns:
        ``'sharegpt'`` or ``'alpaca'``.

    Raises:
        ValueError: If neither known dataset name occurs in ``dataset``.
    """
    lowered = dataset.lower()
    # Order matters: it encodes the precedence between the known names.
    for known_name in ('sharegpt', 'alpaca'):
        if known_name in lowered:
            return known_name
    raise ValueError(f'Unknown dataset: {dataset}')
def get_sampling_dir_name(
    n1: float,
    n2: float,
    n3: float,
    n4: float,
    n6: float,
    n2_beam: float,
    n4_beam: float,
    n6_beam: float,
    n8_beam: float,
) -> str:
    """Build the directory-name fragment describing the sampling mix.

    Each argument is the ratio of requests using that sampling method.
    Methods with a positive ratio contribute ``'<tag>-<ratio>'`` pieces
    joined by ``'-'``, in the fixed order n1, n2, n3, n4, n6, n2-beam,
    n4-beam, n6-beam, n8-beam.  A ratio of exactly 1.0 replaces whatever has
    been accumulated so far with the bare tag (e.g. ``'n4'``).

    Returns:
        The assembled name without a trailing ``'-'``; an empty string when
        no ratio is positive.
    """
    tagged_ratios = (
        ('n1', n1),
        ('n2', n2),
        ('n3', n3),
        ('n4', n4),
        ('n6', n6),
        ('n2-beam', n2_beam),
        ('n4-beam', n4_beam),
        ('n6-beam', n6_beam),
        ('n8-beam', n8_beam),
    )
    name = ''
    for tag, ratio in tagged_ratios:
        if not ratio > 0.0:
            continue
        if ratio == 1.0:
            # A single method covering all requests uses the bare tag.
            name = tag
        else:
            name += f'{tag}-{ratio}-'
    # Drop the separator left behind by the last partial-ratio piece.
    if name.endswith('-'):
        return name[:-1]
    return name
if __name__ == '__main__':
    # Script entry point: parse arguments, validate the sampling mix and the
    # model/dataset pairing, derive the output directory, configure logging,
    # then run the benchmark.
    import math  # Only needed for the tolerant ratio-sum check below.

    parser = argparse.ArgumentParser(
        description='Benchmark the performance on a series of requests.')
    parser = add_server_arguments(parser)
    parser.add_argument('--output-dir', type=str, help='path to output directory', default=None)
    parser.add_argument('--dataset', type=str, help='path to dataset', required=True)
    parser.add_argument('--request-rate', type=float, help='reqs/sec', required=True)
    parser.add_argument('--duration', type=int, help='duration in seconds', required=True)
    # NOTE(review): this flag is not read anywhere in this script; presumably
    # it is consumed by the server code that receives ``args`` -- confirm.
    parser.add_argument('--do-memory-analysis', action='store_true',
        help='do memory analysis (This will lower the throughput. Use this only for analysis.)')
    parser.add_argument('--timeout', type=int, help='time out in seconds', default=None)
    parser.add_argument('--n1', type=float, help='ratio of requests with n=1', default=0.0)
    parser.add_argument('--n2', type=float, help='ratio of requests with n=2', default=0.0)
    parser.add_argument('--n3', type=float, help='ratio of requests with n=3', default=0.0)
    parser.add_argument('--n4', type=float, help='ratio of requests with n=4', default=0.0)
    parser.add_argument('--n6', type=float, help='ratio of requests with n=6', default=0.0)
    parser.add_argument('--n2-beam', type=float, help='ratio of requests with n=2 & beam search', default=0.0)
    parser.add_argument('--n4-beam', type=float, help='ratio of requests with n=4 & beam search', default=0.0)
    parser.add_argument('--n6-beam', type=float, help='ratio of requests with n=6 & beam search', default=0.0)
    parser.add_argument('--n8-beam', type=float, help='ratio of requests with n=8 & beam search', default=0.0)
    args = parser.parse_args()
    args = process_server_arguments(args)

    # The ratios are floats, so compare with a tolerance: an exact ``!= 1.0``
    # test rejects valid mixes whose sum is off by a rounding error.
    ratio_sum = (
        args.n1 + args.n2 + args.n3 + args.n4 + args.n6
        + args.n2_beam + args.n4_beam + args.n6_beam + args.n8_beam
    )
    if not math.isclose(ratio_sum, 1.0, abs_tol=1e-6):
        raise ValueError('The ratios of requests must sum to 1.')

    model_name = get_model_name(args.model)
    # get_dataset_name() is the single source of truth for the dataset name:
    # it matches case-insensitively and raises on unknown datasets instead of
    # silently falling back to 'alpaca'.
    dataset_name = get_dataset_name(args.dataset)

    # The dataset path must name the same model family as the model.
    if 'opt' in model_name:
        if 'opt' not in args.dataset.lower():
            raise ValueError('OPT models can only be used with OPT datasets.')
    elif 'llama' in model_name:
        if 'llama' not in args.dataset.lower():
            raise ValueError('Llama models can only be used with Llama datasets.')

    sample_dir = get_sampling_dir_name(
        args.n1, args.n2, args.n3, args.n4, args.n6, args.n2_beam, args.n4_beam, args.n6_beam, args.n8_beam)
    if args.output_dir is None:
        # Default layout encodes every experiment parameter in the path.
        args.output_dir = os.path.join(
            '../exp',
            dataset_name,
            f'{model_name}-tp{args.tensor_parallel_size}',
            sample_dir,
            'cacheflow',
            f'block{args.block_size}',
            f'req-rate-{args.request_rate}',
            f'seed{args.seed}',
            f'duration-{args.duration}',
        )
    os.makedirs(args.output_dir, exist_ok=True)

    # Set up logging.
    # Done after the output directory exists so the file handler can open
    # log.txt inside it.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(os.path.join(args.output_dir, 'log.txt')),
        ],
    )
    logger.info(args)
    main(args)