[Benchmark] More accurate TPOT calc in benchmark_serving.py (#12288)
Signed-off-by: Nick Hill <nhill@redhat.com>
parent cbdc4ad5a5 | commit 222a9dc350
@@ -35,6 +35,7 @@ class RequestFuncOutput:
     generated_text: str = ""
     success: bool = False
     latency: float = 0.0
+    output_tokens: int = 0
     ttft: float = 0.0  # Time to first token
     itl: List[float] = field(
         default_factory=list)  # List of inter-token latencies
@@ -156,7 +157,7 @@ async def async_request_trt_llm(
                         timestamp = time.perf_counter()
                         # First token
                         if ttft == 0.0:
-                            ttft = time.perf_counter() - st
+                            ttft = timestamp - st
                             output.ttft = ttft

                         # Decoding phase
@@ -245,6 +246,9 @@ async def async_request_openai_completions(
             "logprobs": request_func_input.logprobs,
             "stream": True,
             "ignore_eos": request_func_input.ignore_eos,
+            "stream_options": {
+                "include_usage": True,
+            },
         }
         if request_func_input.extra_body:
             payload.update(request_func_input.extra_body)
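For context: with "stream_options": {"include_usage": True}, OpenAI-compatible servers typically append one final SSE chunk that carries only a usage object and no token text, which is what lets the benchmark record the server-reported completion token count. A minimal sketch of parsing such a trailing chunk (the chunk contents below are illustrative, not taken from the PR):

    import json

    # Hypothetical final SSE data line produced when include_usage is enabled.
    # Earlier chunks carry "choices" with text; this last one carries only "usage".
    raw = ('data: {"id": "cmpl-1", "object": "text_completion", "choices": [], '
           '"usage": {"prompt_tokens": 5, "completion_tokens": 42, "total_tokens": 47}}')

    chunk = raw.removeprefix("data: ")
    if chunk != "[DONE]":
        data = json.loads(chunk)
        if choices := data.get("choices"):
            pass  # normal token chunk: record TTFT / inter-token latency here
        elif usage := data.get("usage"):
            # Server-reported output length, as used by the benchmark.
            output_tokens = usage.get("completion_tokens")
            print(output_tokens)  # -> 42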
@@ -256,7 +260,6 @@ async def async_request_openai_completions(
         output.prompt_len = request_func_input.prompt_len

         generated_text = ""
-        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
@@ -271,15 +274,16 @@ async def async_request_openai_completions(

                         chunk = chunk_bytes.decode("utf-8").removeprefix(
                             "data: ")
-                        if chunk == "[DONE]":
-                            latency = time.perf_counter() - st
-                        else:
+                        if chunk != "[DONE]":
                             data = json.loads(chunk)

                             # NOTE: Some completion API might have a last
                             # usage summary response without a token so we
                             # want to check a token was generated
-                            if data["choices"][0]["text"]:
+                            if choices := data.get("choices"):
+                                # Note that text could be empty here
+                                # e.g. for special tokens
+                                text = choices[0].get("text")
                                 timestamp = time.perf_counter()
                                 # First token
                                 if not first_chunk_received:
@@ -293,7 +297,10 @@ async def async_request_openai_completions(
                                                       most_recent_timestamp)

                                 most_recent_timestamp = timestamp
-                                generated_text += data["choices"][0]["text"]
+                                generated_text += text
+                            elif usage := data.get("usage"):
+                                output.output_tokens = usage.get(
+                                    "completion_tokens")
                     if first_chunk_received:
                         output.success = True
                     else:
@@ -302,7 +309,7 @@ async def async_request_openai_completions(
                             "Never received a valid chunk to calculate TTFT."
                             "This response will be marked as failed!")
                     output.generated_text = generated_text
-                    output.latency = latency
+                    output.latency = most_recent_timestamp - st
                 else:
                     output.error = response.reason or ""
                     output.success = False
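The effect of the output.latency change: previously latency was stamped when the "[DONE]" sentinel arrived, so it included time spent waiting for the trailing usage/sentinel chunks; now it ends at the timestamp of the last received content chunk. A rough numeric sketch (made-up timings) of how this shifts the derived TPOT:

    # Made-up timings for one request, in seconds since the request started.
    ttft = 0.10                  # first token
    last_token_timestamp = 1.10  # last content chunk
    done_timestamp = 1.25        # [DONE]/usage chunk arrives a bit later
    output_len = 101

    old_latency = done_timestamp        # measured at [DONE]
    new_latency = last_token_timestamp  # measured at the last token

    old_tpot = (old_latency - ttft) / (output_len - 1)  # 0.0115 s/token
    new_tpot = (new_latency - ttft) / (output_len - 1)  # 0.0100 s/token
    print(old_tpot, new_tpot)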
@@ -342,6 +349,9 @@ async def async_request_openai_chat_completions(
             "max_completion_tokens": request_func_input.output_len,
             "stream": True,
             "ignore_eos": request_func_input.ignore_eos,
+            "stream_options": {
+                "include_usage": True,
+            },
         }
         if request_func_input.extra_body:
             payload.update(request_func_input.extra_body)
@@ -368,17 +378,15 @@ async def async_request_openai_chat_completions(

                         chunk = chunk_bytes.decode("utf-8").removeprefix(
                             "data: ")
-                        if chunk == "[DONE]":
-                            latency = time.perf_counter() - st
-                        else:
+                        if chunk != "[DONE]":
                             timestamp = time.perf_counter()
                             data = json.loads(chunk)

-                            delta = data["choices"][0]["delta"]
-                            if delta.get("content", None):
+                            if choices := data.get("choices"):
+                                content = choices[0]["delta"].get("content")
                                 # First token
                                 if ttft == 0.0:
-                                    ttft = time.perf_counter() - st
+                                    ttft = timestamp - st
                                     output.ttft = ttft

                                 # Decoding phase
@@ -386,13 +394,16 @@ async def async_request_openai_chat_completions(
                                     output.itl.append(timestamp -
                                                       most_recent_timestamp)

-                                generated_text += delta["content"]
+                                generated_text += content
+                            elif usage := data.get("usage"):
+                                output.output_tokens = usage.get(
+                                    "completion_tokens")

                             most_recent_timestamp = timestamp

                     output.generated_text = generated_text
                     output.success = True
-                    output.latency = latency
+                    output.latency = most_recent_timestamp - st
                 else:
                     output.error = response.reason or ""
                     output.success = False
@@ -25,6 +25,7 @@ On the client side, run:
 import argparse
 import asyncio
 import base64
+import gc
 import io
 import json
 import os
@@ -423,7 +424,7 @@ def calculate_metrics(
     tokenizer: PreTrainedTokenizerBase,
     selected_percentile_metrics: List[str],
     selected_percentiles: List[float],
-    gootput_config_dict: Dict[str, float],
+    goodput_config_dict: Dict[str, float],
 ) -> Tuple[BenchmarkMetrics, List[int]]:
     actual_output_lens: List[int] = []
     total_input = 0
@@ -436,19 +437,23 @@ def calculate_metrics(
     e2els: List[float] = []
     for i in range(len(outputs)):
         if outputs[i].success:
-            # We use the tokenizer to count the number of output tokens for all
-            # serving backends instead of looking at len(outputs[i].itl) since
-            # multiple output tokens may be bundled together
-            # Note : this may inflate the output token count slightly
-            output_len = len(
-                tokenizer(outputs[i].generated_text,
-                          add_special_tokens=False).input_ids)
+            output_len = outputs[i].output_tokens
+
+            if output_len is None:
+                # We use the tokenizer to count the number of output tokens
+                # for some serving backends instead of looking at
+                # len(outputs[i].itl) since multiple output tokens may be
+                # bundled together
+                # Note : this may inflate the output token count slightly
+                output_len = len(
+                    tokenizer(outputs[i].generated_text,
+                              add_special_tokens=False).input_ids)
             actual_output_lens.append(output_len)
             total_input += input_requests[i][1]
             tpot = 0
             if output_len > 1:
-                tpot = (outputs[i].latency - outputs[i].ttft) / (output_len -
-                                                                 1)
+                latency_minus_ttft = outputs[i].latency - outputs[i].ttft
+                tpot = latency_minus_ttft / (output_len - 1)
             tpots.append(tpot)
             # Note: if output_len <= 1, we regard tpot as 0 for goodput
             all_tpots.append(tpot)
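A condensed sketch of the per-request output-length logic above: the server-reported completion token count (from the streamed usage chunk) is preferred, and retokenizing the generated text is only a fallback. Function and variable names here are illustrative, not from the PR:

    from typing import Optional

    def resolve_output_len(output_tokens: Optional[int],
                           generated_text: str,
                           tokenizer) -> int:
        # Prefer the count the server reported via stream usage, if any.
        if output_tokens is not None:
            return output_tokens
        # Fallback: re-tokenize the generated text. This may inflate the count
        # slightly, since streamed tokens can be bundled together and
        # retokenization is not guaranteed to round-trip exactly.
        return len(tokenizer(generated_text,
                             add_special_tokens=False).input_ids)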
@@ -459,21 +464,21 @@ def calculate_metrics(
         else:
             actual_output_lens.append(0)

-    if gootput_config_dict:
+    if goodput_config_dict:
         valid_metrics = []
         slo_values = []

-        if "ttft" in gootput_config_dict:
+        if "ttft" in goodput_config_dict:
             valid_metrics.append(ttfts)
-            slo_values.append(gootput_config_dict["ttft"] /
+            slo_values.append(goodput_config_dict["ttft"] /
                               MILLISECONDS_TO_SECONDS_CONVERSION)
-        if "tpot" in gootput_config_dict:
+        if "tpot" in goodput_config_dict:
             valid_metrics.append(all_tpots)
-            slo_values.append(gootput_config_dict["tpot"] /
+            slo_values.append(goodput_config_dict["tpot"] /
                               MILLISECONDS_TO_SECONDS_CONVERSION)
-        if "e2el" in gootput_config_dict:
+        if "e2el" in goodput_config_dict:
             valid_metrics.append(e2els)
-            slo_values.append(gootput_config_dict["e2el"] /
+            slo_values.append(goodput_config_dict["e2el"] /
                               MILLISECONDS_TO_SECONDS_CONVERSION)

         for req_metric in zip(*valid_metrics):
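For reference, the goodput loop that follows this hunk builds one list of per-request values per configured SLO, then counts a request as "good" only if every metric meets its threshold. A minimal standalone sketch of that idea (thresholds and values made up):

    # Per-request metrics in seconds; one list per SLO being checked.
    ttfts = [0.08, 0.35, 0.12]
    tpots = [0.020, 0.015, 0.060]

    # SLOs converted from milliseconds to seconds (e.g. ttft:200, tpot:50).
    slo_values = [0.200, 0.050]
    valid_metrics = [ttfts, tpots]

    good_completed = 0
    for req_metric in zip(*valid_metrics):
        # A request counts toward goodput only if all SLOs are satisfied.
        if all(value <= slo for value, slo in zip(req_metric, slo_values)):
            good_completed += 1

    print(good_completed)  # -> 1 (only the first request meets both SLOs)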
@@ -537,7 +542,7 @@ async def benchmark(
     selected_percentile_metrics: List[str],
     selected_percentiles: List[str],
     ignore_eos: bool,
-    gootput_config_dict: Dict[str, float],
+    goodput_config_dict: Dict[str, float],
     max_concurrency: Optional[int],
 ):
     if backend in ASYNC_REQUEST_FUNCS:
@@ -661,7 +666,7 @@ async def benchmark(
         tokenizer=tokenizer,
         selected_percentile_metrics=selected_percentile_metrics,
         selected_percentiles=selected_percentiles,
-        gootput_config_dict=gootput_config_dict,
+        goodput_config_dict=goodput_config_dict,
     )

     print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
@@ -673,7 +678,7 @@ async def benchmark(
                                      metrics.total_output))
     print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
                                     metrics.request_throughput))
-    if gootput_config_dict:
+    if goodput_config_dict:
         print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
                                         metrics.request_goodput))
     print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
@@ -688,7 +693,7 @@ async def benchmark(
         "total_output_tokens": metrics.total_output,
         "request_throughput": metrics.request_throughput,
         "request_goodput:":
-        metrics.request_goodput if gootput_config_dict else None,
+        metrics.request_goodput if goodput_config_dict else None,
         "output_throughput": metrics.output_throughput,
         "total_token_throughput": metrics.total_token_throughput,
         "input_lens": [output.prompt_len for output in outputs],
@@ -744,11 +749,11 @@ async def benchmark(

 def check_goodput_args(args):
     # Check and parse goodput arguments
-    gootput_config_dict = {}
+    goodput_config_dict = {}
     VALID_NAMES = ["ttft", "tpot", "e2el"]
     if args.goodput:
-        gootput_config_dict = parse_goodput(args.goodput)
-        for slo_name, slo_val in gootput_config_dict.items():
+        goodput_config_dict = parse_goodput(args.goodput)
+        for slo_name, slo_val in goodput_config_dict.items():
             if slo_name not in VALID_NAMES:
                 raise ValueError(
                     f"Invalid metric name found, {slo_name}: {slo_val}. "
|
|||||||
f"Invalid value found, {slo_name}: {slo_val}. "
|
f"Invalid value found, {slo_name}: {slo_val}. "
|
||||||
"The service level objective value should be "
|
"The service level objective value should be "
|
||||||
"non-negative.")
|
"non-negative.")
|
||||||
return gootput_config_dict
|
return goodput_config_dict
|
||||||
|
|
||||||
|
|
||||||
def parse_goodput(slo_pairs):
|
def parse_goodput(slo_pairs):
|
||||||
gootput_config_dict = {}
|
goodput_config_dict = {}
|
||||||
try:
|
try:
|
||||||
for slo_pair in slo_pairs:
|
for slo_pair in slo_pairs:
|
||||||
slo_name, slo_val = slo_pair.split(":")
|
slo_name, slo_val = slo_pair.split(":")
|
||||||
gootput_config_dict[slo_name] = float(slo_val)
|
goodput_config_dict[slo_name] = float(slo_val)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
raise argparse.ArgumentTypeError(
|
raise argparse.ArgumentTypeError(
|
||||||
"Invalid format found for service level objectives. "
|
"Invalid format found for service level objectives. "
|
||||||
"Specify service level objectives for goodput as \"KEY:VALUE\" "
|
"Specify service level objectives for goodput as \"KEY:VALUE\" "
|
||||||
"pairs, where the key is a metric name, and the value is a "
|
"pairs, where the key is a metric name, and the value is a "
|
||||||
"number in milliseconds.") from err
|
"number in milliseconds.") from err
|
||||||
return gootput_config_dict
|
return goodput_config_dict
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
def main(args: argparse.Namespace):
|
||||||
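A quick usage illustration of the parsing above: the goodput option takes "KEY:VALUE" pairs with values in milliseconds, which parse_goodput turns into a dict of floats. The command-line example below is an assumption about how the pairs are passed, kept consistent with the error message in the code:

    # e.g. pairs taken from something like: --goodput ttft:200 tpot:50
    slo_pairs = ["ttft:200", "tpot:50"]

    goodput_config_dict = {}
    for slo_pair in slo_pairs:
        slo_name, slo_val = slo_pair.split(":")
        goodput_config_dict[slo_name] = float(slo_val)

    print(goodput_config_dict)  # -> {'ttft': 200.0, 'tpot': 50.0}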
@@ -874,7 +879,11 @@ def main(args: argparse.Namespace):
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")

-    gootput_config_dict = check_goodput_args(args)
+    goodput_config_dict = check_goodput_args(args)
+
+    # Avoid GC processing "static" data - reduce pause times.
+    gc.collect()
+    gc.freeze()

     benchmark_result = asyncio.run(
         benchmark(
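On the gc additions: gc.collect() clears garbage created during setup, and gc.freeze() then moves all objects the collector is tracking into a permanent generation that future collections ignore, so the large "static" benchmark data no longer contributes to GC pauses during the measured run. A tiny standalone illustration (the data here is just a stand-in):

    import gc

    static_data = [str(i) * 100 for i in range(100_000)]  # stand-in for prompts etc.

    gc.collect()  # clean up garbage created during setup
    gc.freeze()   # move surviving objects out of the collector's working set
    print(gc.get_freeze_count())  # number of objects now exempt from collection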
@@ -896,7 +905,7 @@ def main(args: argparse.Namespace):
                 float(p) for p in args.metric_percentiles.split(",")
             ],
             ignore_eos=args.ignore_eos,
-            gootput_config_dict=gootput_config_dict,
+            goodput_config_dict=goodput_config_dict,
             max_concurrency=args.max_concurrency,
         ))
