diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index e560cb1f..a0015ab1 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -12,7 +12,6 @@ from vllm import LLM, SamplingParams
 def main(args: argparse.Namespace):
     print(args)
 
-    # Process all the requests in a single batch if possible.
     # NOTE(woosuk): If the request cannot be processed in a single batch,
     # the engine will automatically process the request in multiple batches.
     llm = LLM(
@@ -21,7 +20,6 @@ def main(args: argparse.Namespace):
         quantization=args.quantization,
         tensor_parallel_size=args.tensor_parallel_size,
         max_num_seqs=args.batch_size,
-        max_num_batched_tokens=args.batch_size * args.input_len,
         trust_remote_code=args.trust_remote_code,
         dtype=args.dtype,
     )
@@ -39,22 +37,31 @@ def main(args: argparse.Namespace):
 
     def run_to_completion(profile: bool = False):
         if profile:
-            torch.cuda.cudart().cudaProfilerStart()
-        start_time = time.perf_counter()
-
-        llm.generate(prompt_token_ids=dummy_prompt_token_ids,
-                     sampling_params=sampling_params,
-                     use_tqdm=False)
-
-        end_time = time.perf_counter()
-        latency = end_time - start_time
-        if profile:
-            torch.cuda.cudart().cudaProfilerStop()
-        return latency
+            with torch.profiler.profile(activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+            ]) as p:
+                llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                             sampling_params=sampling_params,
+                             use_tqdm=False)
+            print(p.key_averages())
+        else:
+            start_time = time.perf_counter()
+            llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                         sampling_params=sampling_params,
+                         use_tqdm=False)
+            end_time = time.perf_counter()
+            latency = end_time - start_time
+            return latency
 
     print("Warming up...")
     run_to_completion(profile=False)
 
+    if args.profile:
+        print("Profiling...")
+        run_to_completion(profile=True)
+        return
+
     # Benchmark.
     latencies = []
     for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
@@ -97,5 +104,9 @@ if __name__ == '__main__':
         'The "auto" option will use FP16 precision '
         'for FP32 and FP16 models, and BF16 precision '
         'for BF16 models.')
+    parser.add_argument(
+        '--profile',
+        action='store_true',
+        help='profile the generation process of a single batch')
     args = parser.parse_args()
     main(args)
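
For reference, below is a minimal standalone sketch of the torch.profiler pattern this diff switches to (replacing the cudart cudaProfilerStart/Stop calls). The dummy_workload function is a hypothetical stand-in for llm.generate and is not part of the benchmark script:

    import torch

    def dummy_workload():
        # Hypothetical placeholder for the profiled work
        # (llm.generate in the actual benchmark).
        x = torch.randn(1024, 1024)
        if torch.cuda.is_available():
            x = x.cuda()
        (x @ x).sum().item()

    # Only request the CUDA activity when a GPU is present.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    with torch.profiler.profile(activities=activities) as p:
        dummy_workload()
    # Print aggregated per-op statistics, as the benchmark does.
    print(p.key_averages())

With the new flag, a run might look like `python benchmarks/benchmark_latency.py --profile` (plus whatever model and batch arguments the script already accepts); per the diff, the script then prints the aggregated profiler statistics for a single batch and returns before the timed benchmark iterations.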