"""Benchmark the latency of decoding a single batch of dummy requests."""
import argparse
import time
from typing import List

from tqdm import tqdm
import numpy as np
import torch

from cacheflow.core.server import (
    add_server_arguments, process_server_arguments,
    init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams


def main(args: argparse.Namespace) -> None:
    """Run a warm-up step, then time 3 decoding steps and print the mean.

    Args:
        args: Parsed CLI namespace; must carry the server arguments plus
            input_len, output_len, batch_size, n, and use_beam_search.
    """
    server, frontend = init_local_server_and_frontend_with_arguments(args)

    sampling_params = SamplingParams(
        n=args.n,
        # Beam search is deterministic, so temperature is forced to 0.
        temperature=0.0 if args.use_beam_search else 1.0,
        top_p=1.0,
        use_beam_search=args.use_beam_search,
        stop_token_ids=set(),
        max_tokens=args.output_len,
    )
    print(sampling_params)
    # Dummy prompt: a run of token id 0 of the requested input length.
    input_token_ids = [0] * args.input_len

    def profile_step(profile: bool = False) -> float:
        """Decode one full batch to completion and return the elapsed seconds.

        When `profile` is True, the step is bracketed with CUDA profiler
        start/stop markers so external tools (e.g. nsys) can capture it.
        """
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        # Enqueue `batch_size` identical dummy requests.
        for _ in range(args.batch_size):
            dummy_prompt = ""
            frontend._add_query(dummy_prompt, input_token_ids, sampling_params)
        server.add_sequence_groups(frontend.get_inputs())
        # Use a monotonic high-resolution clock for timing; time.time() is
        # wall-clock and can jump under NTP/system clock adjustments.
        start_time = time.perf_counter()
        # Step the server until every request in the batch has finished.
        while True:
            server.step()
            if not server.has_unfinished_requests():
                break
        end_time = time.perf_counter()
        latency = end_time - start_time
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return latency

    print("Warm up step")
    profile_step()

    # Benchmark.
    latencies = []
    for _ in tqdm(range(3), desc="Profile step"):
        latencies.append(profile_step())
    print(f'Avg latency: {np.mean(latencies)} seconds')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Benchmark the latency of decoding a single sentence.')
    parser = add_server_arguments(parser)
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n', type=int, default=1)
    parser.add_argument('--use-beam-search', action='store_true')
    args = parser.parse_args()
    args = process_server_arguments(args)
    # Ensure the server can fit the whole benchmark batch in one step.
    args.max_num_batched_tokens = max(
        args.max_num_batched_tokens, args.batch_size * args.input_len)
    print(args)
    main(args)