Add profile option to latency benchmark script (#1839)
commit e74b1736a1
parent f07c1ceaa5
@@ -12,7 +12,6 @@ from vllm import LLM, SamplingParams
 def main(args: argparse.Namespace):
     print(args)

-    # Process all the requests in a single batch if possible.
     # NOTE(woosuk): If the request cannot be processed in a single batch,
     # the engine will automatically process the request in multiple batches.
     llm = LLM(
@@ -21,7 +20,6 @@ def main(args: argparse.Namespace):
         quantization=args.quantization,
         tensor_parallel_size=args.tensor_parallel_size,
         max_num_seqs=args.batch_size,
-        max_num_batched_tokens=args.batch_size * args.input_len,
         trust_remote_code=args.trust_remote_code,
         dtype=args.dtype,
     )
@@ -39,22 +37,31 @@ def main(args: argparse.Namespace):

     def run_to_completion(profile: bool = False):
         if profile:
-            torch.cuda.cudart().cudaProfilerStart()
-        start_time = time.perf_counter()
-
-        llm.generate(prompt_token_ids=dummy_prompt_token_ids,
-                     sampling_params=sampling_params,
-                     use_tqdm=False)
-
-        end_time = time.perf_counter()
-        latency = end_time - start_time
-        if profile:
-            torch.cuda.cudart().cudaProfilerStop()
-        return latency
+            with torch.profiler.profile(activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+            ]) as p:
+                llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                             sampling_params=sampling_params,
+                             use_tqdm=False)
+            print(p.key_averages())
+        else:
+            start_time = time.perf_counter()
+            llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                         sampling_params=sampling_params,
+                         use_tqdm=False)
+            end_time = time.perf_counter()
+            latency = end_time - start_time
+            return latency

     print("Warming up...")
     run_to_completion(profile=False)

+    if args.profile:
+        print("Profiling...")
+        run_to_completion(profile=True)
+        return
+
     # Benchmark.
     latencies = []
     for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
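The new branch swaps the `cudaProfilerStart`/`cudaProfilerStop` markers, which only delimit a capture range for an external profiler such as Nsight Systems, for `torch.profiler`, which records operator-level timings in-process. The pattern works the same outside vLLM; a self-contained sketch, with a matmul loop standing in for `llm.generate(...)`:

    import torch

    x = torch.randn(1024, 1024)
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        x = x.cuda()
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    with torch.profiler.profile(activities=activities) as p:
        for _ in range(10):
            y = x @ x  # placeholder for the workload being profiled

    # key_averages() aggregates recorded events by operator name; .table()
    # renders a sorted summary (the script prints the raw EventList instead).
    print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))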
@@ -97,5 +104,9 @@ if __name__ == '__main__':
                         'The "auto" option will use FP16 precision '
                         'for FP32 and FP16 models, and BF16 precision '
                         'for BF16 models.')
+    parser.add_argument(
+        '--profile',
+        action='store_true',
+        help='profile the generation process of a single batch')
     args = parser.parse_args()
     main(args)