Add JSON output support for benchmark_latency and benchmark_throughput (#4848)
This commit is contained in:
parent
6979ade384
commit
f09edd8a25
@ -9,10 +9,10 @@ cd "$(dirname "${BASH_SOURCE[0]}")/.."
|
||||
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
|
||||
|
||||
# run python-based benchmarks and upload the result to buildkite
|
||||
python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt
|
||||
python3 benchmarks/benchmark_latency.py --output-json latency_results.json 2>&1 | tee benchmark_latency.txt
|
||||
bench_latency_exit_code=$?
|
||||
|
||||
python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt
|
||||
python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt
|
||||
bench_throughput_exit_code=$?
|
||||
|
||||
# run server-based benchmarks and upload the result to buildkite
|
||||
@ -74,4 +74,5 @@ if [ $bench_serving_exit_code -ne 0 ]; then
|
||||
exit $bench_serving_exit_code
|
||||
fi
|
||||
|
||||
/workspace/buildkite-agent artifact upload openai-*.json
|
||||
rm ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
/workspace/buildkite-agent artifact upload "*.json"
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Benchmark the latency of processing a single batch of requests."""
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@ -96,6 +97,16 @@ def main(args: argparse.Namespace):
|
||||
for percentage, percentile in zip(percentages, percentiles):
|
||||
print(f'{percentage}% percentile latency: {percentile} seconds')
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
results = {
|
||||
"avg_latency": np.mean(latencies),
|
||||
"latencies": latencies.tolist(),
|
||||
"percentiles": dict(zip(percentages, percentiles.tolist())),
|
||||
}
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -149,8 +160,8 @@ if __name__ == '__main__':
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type. '
|
||||
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
|
||||
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
|
||||
'common inference criteria.')
|
||||
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
|
||||
'instead supported for common inference criteria.')
|
||||
parser.add_argument(
|
||||
'--quantization-param-path',
|
||||
type=str,
|
||||
@ -197,5 +208,10 @@ if __name__ == '__main__':
|
||||
default=None,
|
||||
help='directory to download and load the weights, '
|
||||
'default to the default cache dir of huggingface')
|
||||
parser.add_argument(
|
||||
'--output-json',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to save the latency results in JSON format.')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -242,6 +242,18 @@ def main(args: argparse.Namespace):
|
||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
results = {
|
||||
"elapsed_time": elapsed_time,
|
||||
"num_requests": len(requests),
|
||||
"total_num_tokens": total_num_tokens,
|
||||
"requests_per_second": len(requests) / elapsed_time,
|
||||
"tokens_per_second": total_num_tokens / elapsed_time,
|
||||
}
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
||||
@ -353,6 +365,11 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help='directory to download and load the weights, '
|
||||
'default to the default cache dir of huggingface')
|
||||
parser.add_argument(
|
||||
'--output-json',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to save the throughput results in JSON format.')
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
|
Loading…
x
Reference in New Issue
Block a user