[Misc] benchmark: Add option to set max concurrency (#9390)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
parent
ae8b633ba3
commit
7dbe738d65
@ -398,6 +398,7 @@ async def benchmark(
|
|||||||
selected_percentile_metrics: List[str],
|
selected_percentile_metrics: List[str],
|
||||||
selected_percentiles: List[str],
|
selected_percentiles: List[str],
|
||||||
ignore_eos: bool,
|
ignore_eos: bool,
|
||||||
|
max_concurrency: Optional[int],
|
||||||
):
|
):
|
||||||
if backend in ASYNC_REQUEST_FUNCS:
|
if backend in ASYNC_REQUEST_FUNCS:
|
||||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||||
@ -446,9 +447,25 @@ async def benchmark(
|
|||||||
print("Profiler started")
|
print("Profiler started")
|
||||||
|
|
||||||
print(f"Traffic request rate: {request_rate}")
|
print(f"Traffic request rate: {request_rate}")
|
||||||
|
print(f"Maximum request concurrency: {max_concurrency}")
|
||||||
|
|
||||||
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||||
|
|
||||||
|
# This can be used once the minimum Python version is 3.10 or higher,
|
||||||
|
# and it will simplify the code in limited_request_func.
|
||||||
|
# semaphore = (asyncio.Semaphore(max_concurrency)
|
||||||
|
# if max_concurrency else contextlib.nullcontext())
|
||||||
|
semaphore = (asyncio.Semaphore(max_concurrency)
|
||||||
|
if max_concurrency else None)
|
||||||
|
|
||||||
|
async def limited_request_func(request_func_input, pbar):
|
||||||
|
if semaphore is None:
|
||||||
|
return await request_func(request_func_input=request_func_input,
|
||||||
|
pbar=pbar)
|
||||||
|
async with semaphore:
|
||||||
|
return await request_func(request_func_input=request_func_input,
|
||||||
|
pbar=pbar)
|
||||||
|
|
||||||
benchmark_start_time = time.perf_counter()
|
benchmark_start_time = time.perf_counter()
|
||||||
tasks: List[asyncio.Task] = []
|
tasks: List[asyncio.Task] = []
|
||||||
async for request in get_request(input_requests, request_rate):
|
async for request in get_request(input_requests, request_rate):
|
||||||
@ -464,7 +481,7 @@ async def benchmark(
|
|||||||
ignore_eos=ignore_eos)
|
ignore_eos=ignore_eos)
|
||||||
tasks.append(
|
tasks.append(
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
request_func(request_func_input=request_func_input,
|
limited_request_func(request_func_input=request_func_input,
|
||||||
pbar=pbar)))
|
pbar=pbar)))
|
||||||
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
@ -682,6 +699,7 @@ def main(args: argparse.Namespace):
|
|||||||
float(p) for p in args.metric_percentiles.split(",")
|
float(p) for p in args.metric_percentiles.split(",")
|
||||||
],
|
],
|
||||||
ignore_eos=args.ignore_eos,
|
ignore_eos=args.ignore_eos,
|
||||||
|
max_concurrency=args.max_concurrency,
|
||||||
))
|
))
|
||||||
|
|
||||||
# Save config and results to json
|
# Save config and results to json
|
||||||
@ -711,13 +729,16 @@ def main(args: argparse.Namespace):
|
|||||||
# Traffic
|
# Traffic
|
||||||
result_json["request_rate"] = (
|
result_json["request_rate"] = (
|
||||||
args.request_rate if args.request_rate < float("inf") else "inf")
|
args.request_rate if args.request_rate < float("inf") else "inf")
|
||||||
|
result_json["max_concurrency"] = args.max_concurrency
|
||||||
|
|
||||||
# Merge with benchmark result
|
# Merge with benchmark result
|
||||||
result_json = {**result_json, **benchmark_result}
|
result_json = {**result_json, **benchmark_result}
|
||||||
|
|
||||||
# Save to file
|
# Save to file
|
||||||
base_model_id = model_id.split("/")[-1]
|
base_model_id = model_id.split("/")[-1]
|
||||||
file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" #noqa
|
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
|
||||||
|
if args.max_concurrency is not None else "")
|
||||||
|
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa
|
||||||
if args.result_filename:
|
if args.result_filename:
|
||||||
file_name = args.result_filename
|
file_name = args.result_filename
|
||||||
if args.result_dir:
|
if args.result_dir:
|
||||||
@ -768,6 +789,19 @@ if __name__ == "__main__":
|
|||||||
default=None,
|
default=None,
|
||||||
help="Path to the sharegpt/sonnet dataset. "
|
help="Path to the sharegpt/sonnet dataset. "
|
||||||
"Or the huggingface dataset ID if using HF dataset.")
|
"Or the huggingface dataset ID if using HF dataset.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-concurrency",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Maximum number of concurrent requests. This can be used "
|
||||||
|
"to help simulate an environment where a higher level component "
|
||||||
|
"is enforcing a maximum number of concurrent requests. While the "
|
||||||
|
"--request-rate argument controls the rate at which requests are "
|
||||||
|
"initiated, this argument will control how many are actually allowed "
|
||||||
|
"to execute at a time. This means that when used in combination, the "
|
||||||
|
"actual request rate may be lower than specified with --request-rate, "
|
||||||
|
"if the server is not processing requests fast enough to keep up.")
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user