vllm/tools/profiler/print_layerwise_table.py
2025-03-02 17:34:51 -08:00

84 lines
2.8 KiB
Python

# SPDX-License-Identifier: Apache-2.0
import argparse
import json
from vllm.profiler.layerwise_profile import ModelStatsEntry, SummaryStatsEntry
from vllm.profiler.utils import TablePrinter, indent_string
def flatten_entries(entry_cls, profile_dict: list) -> list:
    """Flatten a forest of profile nodes into a depth-annotated list.

    Each node is a mapping with an ``"entry"`` item (keyword arguments for
    ``entry_cls``) and a ``"children"`` item (an iterable of nodes of the
    same shape).

    Args:
        entry_cls: Callable invoked as ``entry_cls(**node["entry"])`` to
            build the object recorded for each node.
        profile_dict: Iterable of root nodes. Despite the parameter name,
            this is a sequence of node mappings rather than a single dict:
            it is iterated directly and every element is indexed by
            ``"entry"`` and ``"children"``.

    Returns:
        A list of ``(entry, depth)`` tuples in depth-first pre-order. Roots
        have depth 0 and every child is one level deeper than its parent.

    Raises:
        KeyError: If a node mapping lacks ``"entry"`` or ``"children"``.
    """
    entries_and_depth = []

    def get_entries(node, curr_depth=0):
        # Record the parent before descending so the output is pre-order.
        entries_and_depth.append((entry_cls(**node["entry"]), curr_depth))
        for child in node["children"]:
            get_entries(child, curr_depth=curr_depth + 1)

    for root in profile_dict:
        get_entries(root)
    return entries_and_depth
if __name__ == "__main__":
    # Command-line entry point: load a JSON trace, select one phase, and
    # print either the summary table or the layerwise model table for it.
    parser = argparse.ArgumentParser()
    parser.add_argument("--json-trace",
                        type=str,
                        required=True,
                        help="json trace file output by "
                        "examples/offline_inference/profiling.py")
    parser.add_argument("--phase",
                        type=str,
                        required=True,
                        help="The phase to print the table for. This is "
                        "either prefill or decode_n, where n is the decode "
                        "step number")
    parser.add_argument("--table",
                        type=str,
                        choices=["summary", "model"],
                        default="summary",
                        help="Which table to print, the summary table or the "
                        "layerwise model table")
    args = parser.parse_args()

    with open(args.json_trace) as f:
        profile_data = json.load(f)

    # Validate with parser.error rather than assert: asserts are stripped
    # when Python runs with -O, which would turn a bad --phase into an
    # opaque KeyError further down.
    if args.phase not in profile_data:
        available_phases = [
            key for key in profile_data
            if "prefill" in key or "decode" in key
        ]
        parser.error(
            f"Cannot find phase {args.phase} in profile data. "
            f"Choose one among {available_phases}")

    # Per-table configuration: (entry class, key inside the phase data,
    # column widths handed to TablePrinter).
    table_specs = {
        "summary": (SummaryStatsEntry, "summary_stats",
                    dict(name=80,
                         cuda_time_us=12,
                         pct_cuda_time=12,
                         invocations=15)),
        "model": (ModelStatsEntry, "model_stats",
                  dict(name=60,
                       cpu_time_us=12,
                       cuda_time_us=12,
                       pct_cuda_time=12,
                       trace=60)),
    }
    entry_cls, stats_key, column_widths = table_specs[args.table]
    entries_and_depths = flatten_entries(entry_cls,
                                         profile_data[args.phase][stats_key])

    # indent entry names based on the depth
    entries = []
    for entry, depth in entries_and_depths:
        entry.name = indent_string(
            entry.name,
            indent=depth,
            indent_style=lambda indent: "|" + "-" * indent + " ")
        entries.append(entry)

    # Pass the entry class directly instead of type(entries[0]), which
    # raises IndexError when the selected phase has no entries.
    TablePrinter(entry_cls, column_widths).print_table(entries)