Save pytorch profiler output for latency benchmark (#1871)

* Save profiler output

* Apply feedback from code review
Antoni Baum 2023-12-05 20:55:55 -08:00 committed by GitHub
parent 1d9b737e05
commit 05ff90b692

@@ -1,6 +1,8 @@
 """Benchmark the latency of processing a single batch of requests."""
 import argparse
 import time
+from pathlib import Path
+from typing import Optional
 
 import numpy as np
 import torch
@@ -34,12 +36,15 @@ def main(args: argparse.Namespace):
     print(sampling_params)
     dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
 
-    def run_to_completion(profile: bool = False):
-        if profile:
-            with torch.profiler.profile(activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-            ]) as p:
+    def run_to_completion(profile_dir: Optional[str] = None):
+        if profile_dir:
+            with torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.CUDA,
+                    ],
+                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                        str(profile_dir))) as p:
                 llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                              sampling_params=sampling_params,
                              use_tqdm=False)
@@ -54,11 +59,14 @@ def main(args: argparse.Namespace):
         return latency
 
     print("Warming up...")
-    run_to_completion(profile=False)
+    run_to_completion(profile_dir=None)
 
     if args.profile:
-        print("Profiling...")
-        run_to_completion(profile=True)
+        profile_dir = args.profile_result_dir
+        if not profile_dir:
+            profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
+        print(f"Profiling (results will be saved to '{profile_dir}')...")
+        run_to_completion(profile_dir=args.profile_result_dir)
         return
 
     # Benchmark.
@@ -107,5 +115,13 @@ if __name__ == '__main__':
         '--profile',
         action='store_true',
         help='profile the generation process of a single batch')
+    parser.add_argument(
+        '--profile-result-dir',
+        type=str,
+        default=None,
+        help=(
+            'path to save the pytorch profiler output. Can be visualized '
+            'with ui.perfetto.dev or Tensorboard.'
+        ))
     args = parser.parse_args()
     main(args)
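
For reference, the profiling pattern this diff adopts can be exercised on its own. The sketch below is not part of the commit: the matmul workload and the profiler_output directory are illustrative stand-ins for llm.generate(...) and the new --profile-result-dir value. It only shows how torch.profiler.tensorboard_trace_handler writes a trace into a directory when the profiler context exits.

from pathlib import Path

import torch


def dummy_workload() -> torch.Tensor:
    # Stand-in for the profiled region (llm.generate in the benchmark):
    # any CPU/GPU work is enough to produce a trace.
    x = torch.randn(1024, 1024)
    return x @ x


if __name__ == "__main__":
    profile_dir = Path(".") / "profiler_output"  # hypothetical output directory
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    with torch.profiler.profile(
            activities=activities,
            # Writes a TensorBoard-compatible *.pt.trace.json file into
            # profile_dir when the `with` block exits.
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                str(profile_dir))) as p:
        dummy_workload()
    # Optional console summary of the captured events.
    print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

With this change the benchmark exposes the same behaviour through --profile and, optionally, --profile-result-dir; the saved trace directory can then be opened in TensorBoard (with the torch-tb-profiler plugin) or the JSON trace loaded at ui.perfetto.dev.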