[Benchmark] More accurate TPOT calc in benchmark_serving.py (#12288)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-01-21 21:46:14 -08:00 committed by GitHub
parent cbdc4ad5a5
commit 222a9dc350
2 changed files with 66 additions and 46 deletions
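
Note on the metric: TPOT (time per output token) is the average decode-phase latency per generated token, excluding the first token, which TTFT already covers. This commit makes the calculation more accurate in two ways: the output token count now comes from the server-reported usage stats when available (instead of re-tokenizing the generated text, which can over-count), and end-to-end latency is now measured to the last received chunk. A minimal sketch of the formula, not part of the diff:

    # TPOT = (end_to_end_latency - ttft) / (output_tokens - 1)
    def tpot(latency: float, ttft: float, output_tokens: int) -> float:
        # Guard against division by zero for single-token outputs.
        return (latency - ttft) / (output_tokens - 1) if output_tokens > 1 else 0.0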

backend_request_func.py

@@ -35,6 +35,7 @@ class RequestFuncOutput:
     generated_text: str = ""
     success: bool = False
     latency: float = 0.0
+    output_tokens: int = 0
     ttft: float = 0.0  # Time to first token
     itl: List[float] = field(
         default_factory=list)  # List of inter-token latencies
@@ -156,7 +157,7 @@ async def async_request_trt_llm(
                     timestamp = time.perf_counter()
                     # First token
                     if ttft == 0.0:
-                        ttft = time.perf_counter() - st
+                        ttft = timestamp - st
                         output.ttft = ttft

                     # Decoding phase
@ -245,6 +246,9 @@ async def async_request_openai_completions(
"logprobs": request_func_input.logprobs, "logprobs": request_func_input.logprobs,
"stream": True, "stream": True,
"ignore_eos": request_func_input.ignore_eos, "ignore_eos": request_func_input.ignore_eos,
"stream_options": {
"include_usage": True,
},
} }
if request_func_input.extra_body: if request_func_input.extra_body:
payload.update(request_func_input.extra_body) payload.update(request_func_input.extra_body)
@@ -256,7 +260,6 @@ async def async_request_openai_completions(
         output.prompt_len = request_func_input.prompt_len

         generated_text = ""
-        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
@ -271,15 +274,16 @@ async def async_request_openai_completions(
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ") "data: ")
if chunk == "[DONE]": if chunk != "[DONE]":
latency = time.perf_counter() - st
else:
data = json.loads(chunk) data = json.loads(chunk)
# NOTE: Some completion API might have a last # NOTE: Some completion API might have a last
# usage summary response without a token so we # usage summary response without a token so we
# want to check a token was generated # want to check a token was generated
if data["choices"][0]["text"]: if choices := data.get("choices"):
# Note that text could be empty here
# e.g. for special tokens
text = choices[0].get("text")
timestamp = time.perf_counter() timestamp = time.perf_counter()
# First token # First token
if not first_chunk_received: if not first_chunk_received:
@@ -293,7 +297,10 @@ async def async_request_openai_completions(
                                                   most_recent_timestamp)
                             most_recent_timestamp = timestamp
-                            generated_text += data["choices"][0]["text"]
+                            generated_text += text
+                        elif usage := data.get("usage"):
+                            output.output_tokens = usage.get(
+                                "completion_tokens")

             if first_chunk_received:
                 output.success = True
             else:
@@ -302,7 +309,7 @@ async def async_request_openai_completions(
                         "Never received a valid chunk to calculate TTFT."
                         "This response will be marked as failed!")
                 output.generated_text = generated_text
-                output.latency = latency
+                output.latency = most_recent_timestamp - st
             else:
                 output.error = response.reason or ""
                 output.success = False
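
End-to-end latency is now taken as most_recent_timestamp - st, i.e. up to the last chunk actually processed, instead of being captured when the "[DONE]" sentinel arrives. A condensed sketch of the timing pattern, assuming a hypothetical stream iterable:

    import time

    st = time.perf_counter()               # stream start
    most_recent_timestamp = st
    for chunk in stream:                   # hypothetical iterable of SSE payloads
        if chunk == "[DONE]":
            continue                       # sentinel carries no timing information
        most_recent_timestamp = time.perf_counter()
    latency = most_recent_timestamp - st   # measured to the last real chunk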
@ -342,6 +349,9 @@ async def async_request_openai_chat_completions(
"max_completion_tokens": request_func_input.output_len, "max_completion_tokens": request_func_input.output_len,
"stream": True, "stream": True,
"ignore_eos": request_func_input.ignore_eos, "ignore_eos": request_func_input.ignore_eos,
"stream_options": {
"include_usage": True,
},
} }
if request_func_input.extra_body: if request_func_input.extra_body:
payload.update(request_func_input.extra_body) payload.update(request_func_input.extra_body)
@@ -368,17 +378,15 @@ async def async_request_openai_chat_completions(

                     chunk = chunk_bytes.decode("utf-8").removeprefix(
                         "data: ")
-                    if chunk == "[DONE]":
-                        latency = time.perf_counter() - st
-                    else:
+                    if chunk != "[DONE]":
                         timestamp = time.perf_counter()
                         data = json.loads(chunk)

-                        delta = data["choices"][0]["delta"]
-                        if delta.get("content", None):
+                        if choices := data.get("choices"):
+                            content = choices[0]["delta"].get("content")
                             # First token
                             if ttft == 0.0:
-                                ttft = time.perf_counter() - st
+                                ttft = timestamp - st
                                 output.ttft = ttft

                             # Decoding phase
@@ -386,13 +394,16 @@ async def async_request_openai_chat_completions(
                                 output.itl.append(timestamp -
                                                   most_recent_timestamp)

-                            generated_text += delta["content"]
+                            generated_text += content
+                        elif usage := data.get("usage"):
+                            output.output_tokens = usage.get(
+                                "completion_tokens")

                         most_recent_timestamp = timestamp

             output.generated_text = generated_text
             output.success = True
-            output.latency = latency
+            output.latency = most_recent_timestamp - st
         else:
             output.error = response.reason or ""
             output.success = False

benchmark_serving.py

@ -25,6 +25,7 @@ On the client side, run:
import argparse import argparse
import asyncio import asyncio
import base64 import base64
import gc
import io import io
import json import json
import os import os
@ -423,7 +424,7 @@ def calculate_metrics(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
selected_percentile_metrics: List[str], selected_percentile_metrics: List[str],
selected_percentiles: List[float], selected_percentiles: List[float],
gootput_config_dict: Dict[str, float], goodput_config_dict: Dict[str, float],
) -> Tuple[BenchmarkMetrics, List[int]]: ) -> Tuple[BenchmarkMetrics, List[int]]:
actual_output_lens: List[int] = [] actual_output_lens: List[int] = []
total_input = 0 total_input = 0
@@ -436,19 +437,23 @@ def calculate_metrics(
     e2els: List[float] = []
     for i in range(len(outputs)):
         if outputs[i].success:
-            # We use the tokenizer to count the number of output tokens for all
-            # serving backends instead of looking at len(outputs[i].itl) since
-            # multiple output tokens may be bundled together
-            # Note : this may inflate the output token count slightly
-            output_len = len(
-                tokenizer(outputs[i].generated_text,
-                          add_special_tokens=False).input_ids)
+            output_len = outputs[i].output_tokens
+
+            if output_len is None:
+                # We use the tokenizer to count the number of output tokens
+                # for some serving backends instead of looking at
+                # len(outputs[i].itl) since multiple output tokens may be
+                # bundled together
+                # Note : this may inflate the output token count slightly
+                output_len = len(
+                    tokenizer(outputs[i].generated_text,
+                              add_special_tokens=False).input_ids)
             actual_output_lens.append(output_len)
             total_input += input_requests[i][1]
             tpot = 0
             if output_len > 1:
-                tpot = (outputs[i].latency - outputs[i].ttft) / (output_len -
-                                                                 1)
+                latency_minus_ttft = outputs[i].latency - outputs[i].ttft
+                tpot = latency_minus_ttft / (output_len - 1)
             tpots.append(tpot)
             # Note: if output_len <= 1, we regard tpot as 0 for goodput
             all_tpots.append(tpot)
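
A worked example of the refined calculation, with invented numbers:

    # 3.2 s end-to-end latency, 0.2 s TTFT, usage-reported completion_tokens = 151
    latency, ttft, output_len = 3.2, 0.2, 151
    latency_minus_ttft = latency - ttft           # 3.0 s spent decoding
    tpot = latency_minus_ttft / (output_len - 1)  # 3.0 / 150 = 0.02 s = 20 ms/token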
@@ -459,21 +464,21 @@ def calculate_metrics(
         else:
             actual_output_lens.append(0)

-    if gootput_config_dict:
+    if goodput_config_dict:
         valid_metrics = []
         slo_values = []

-        if "ttft" in gootput_config_dict:
+        if "ttft" in goodput_config_dict:
             valid_metrics.append(ttfts)
-            slo_values.append(gootput_config_dict["ttft"] /
+            slo_values.append(goodput_config_dict["ttft"] /
                               MILLISECONDS_TO_SECONDS_CONVERSION)
-        if "tpot" in gootput_config_dict:
+        if "tpot" in goodput_config_dict:
             valid_metrics.append(all_tpots)
-            slo_values.append(gootput_config_dict["tpot"] /
+            slo_values.append(goodput_config_dict["tpot"] /
                               MILLISECONDS_TO_SECONDS_CONVERSION)
-        if "e2el" in gootput_config_dict:
+        if "e2el" in goodput_config_dict:
             valid_metrics.append(e2els)
-            slo_values.append(gootput_config_dict["e2el"] /
+            slo_values.append(goodput_config_dict["e2el"] /
                               MILLISECONDS_TO_SECONDS_CONVERSION)

         for req_metric in zip(*valid_metrics):
@ -537,7 +542,7 @@ async def benchmark(
selected_percentile_metrics: List[str], selected_percentile_metrics: List[str],
selected_percentiles: List[str], selected_percentiles: List[str],
ignore_eos: bool, ignore_eos: bool,
gootput_config_dict: Dict[str, float], goodput_config_dict: Dict[str, float],
max_concurrency: Optional[int], max_concurrency: Optional[int],
): ):
if backend in ASYNC_REQUEST_FUNCS: if backend in ASYNC_REQUEST_FUNCS:
@@ -661,7 +666,7 @@ async def benchmark(
         tokenizer=tokenizer,
         selected_percentile_metrics=selected_percentile_metrics,
         selected_percentiles=selected_percentiles,
-        gootput_config_dict=gootput_config_dict,
+        goodput_config_dict=goodput_config_dict,
     )

     print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
@@ -673,7 +678,7 @@ async def benchmark(
                                     metrics.total_output))
     print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
                                     metrics.request_throughput))
-    if gootput_config_dict:
+    if goodput_config_dict:
         print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
                                         metrics.request_goodput))
     print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
@@ -688,7 +693,7 @@ async def benchmark(
         "total_output_tokens": metrics.total_output,
         "request_throughput": metrics.request_throughput,
         "request_goodput:":
-        metrics.request_goodput if gootput_config_dict else None,
+        metrics.request_goodput if goodput_config_dict else None,
         "output_throughput": metrics.output_throughput,
         "total_token_throughput": metrics.total_token_throughput,
         "input_lens": [output.prompt_len for output in outputs],
@@ -744,11 +749,11 @@ async def benchmark(

 def check_goodput_args(args):
     # Check and parse goodput arguments
-    gootput_config_dict = {}
+    goodput_config_dict = {}
     VALID_NAMES = ["ttft", "tpot", "e2el"]
     if args.goodput:
-        gootput_config_dict = parse_goodput(args.goodput)
-        for slo_name, slo_val in gootput_config_dict.items():
+        goodput_config_dict = parse_goodput(args.goodput)
+        for slo_name, slo_val in goodput_config_dict.items():
             if slo_name not in VALID_NAMES:
                 raise ValueError(
                     f"Invalid metric name found, {slo_name}: {slo_val}. "
@@ -759,22 +764,22 @@ def check_goodput_args(args):
                     f"Invalid value found, {slo_name}: {slo_val}. "
                     "The service level objective value should be "
                     "non-negative.")
-    return gootput_config_dict
+    return goodput_config_dict


 def parse_goodput(slo_pairs):
-    gootput_config_dict = {}
+    goodput_config_dict = {}
     try:
         for slo_pair in slo_pairs:
             slo_name, slo_val = slo_pair.split(":")
-            gootput_config_dict[slo_name] = float(slo_val)
+            goodput_config_dict[slo_name] = float(slo_val)
     except ValueError as err:
         raise argparse.ArgumentTypeError(
             "Invalid format found for service level objectives. "
             "Specify service level objectives for goodput as \"KEY:VALUE\" "
             "pairs, where the key is a metric name, and the value is a "
             "number in milliseconds.") from err
-    return gootput_config_dict
+    return goodput_config_dict


 def main(args: argparse.Namespace):
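
For reference, parse_goodput expects space-separated KEY:VALUE pairs with values in milliseconds, so a hypothetical invocation looks like:

    # python benchmark_serving.py ... --goodput ttft:500 tpot:50
    parse_goodput(["ttft:500", "tpot:50"])  # -> {"ttft": 500.0, "tpot": 50.0}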
@@ -874,7 +879,11 @@ def main(args: argparse.Namespace):
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")

-    gootput_config_dict = check_goodput_args(args)
+    goodput_config_dict = check_goodput_args(args)
+
+    # Avoid GC processing "static" data - reduce pause times.
+    gc.collect()
+    gc.freeze()

     benchmark_result = asyncio.run(
         benchmark(
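
gc.collect() followed by gc.freeze() moves every object surviving the collection into a permanent generation that the cyclic collector skips, so the large, effectively static request dataset no longer contributes to GC pause times inside the timed benchmark loop. A minimal sketch of the pattern:

    import gc

    data = [object() for _ in range(1_000_000)]  # stand-in for static benchmark data
    gc.collect()                   # collect garbage created during setup
    gc.freeze()                    # move survivors to the permanent generation
    print(gc.get_freeze_count())   # number of objects now exempt from GC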
@@ -896,7 +905,7 @@ def main(args: argparse.Namespace):
                 float(p) for p in args.metric_percentiles.split(",")
             ],
             ignore_eos=args.ignore_eos,
-            gootput_config_dict=gootput_config_dict,
+            goodput_config_dict=goodput_config_dict,
             max_concurrency=args.max_concurrency,
         ))