# SPDX-License-Identifier: Apache-2.0
"""Print the summary or layerwise model table stored in a JSON profile trace."""

import argparse
import json

from vllm.profiler.layerwise_profile import ModelStatsEntry, SummaryStatsEntry
from vllm.profiler.utils import TablePrinter, indent_string


def flatten_entries(entry_cls, profile_dict: dict):
    """Flatten a forest of profile nodes into ``(entry, depth)`` pairs.

    Every node must be a mapping with an ``"entry"`` key (keyword arguments
    for ``entry_cls``) and a ``"children"`` key (an iterable of nodes of the
    same shape).

    Args:
        entry_cls: Callable invoked as ``entry_cls(**node["entry"])`` to
            build the object stored for each node.
        profile_dict: Iterable of root nodes.
            NOTE(review): the annotation says ``dict``, but the value is
            iterated and each item is indexed with ``"entry"`` and
            ``"children"``, so callers presumably pass a list of node
            dicts -- confirm and tighten the annotation.

    Returns:
        A list of ``(entry, depth)`` tuples in pre-order (a node precedes
        its children); root nodes have depth 0.
    """
    entries_and_depth = []

    def get_entries(node, curr_depth=0):
        entries_and_depth.append((entry_cls(**node["entry"]), curr_depth))
        for child in node["children"]:
            get_entries(child, curr_depth=curr_depth + 1)

    for root in profile_dict:
        get_entries(root)

    return entries_and_depth


def _indent_prefix(indent):
    """Return the tree-style prefix placed before a name at depth ``indent``."""
    return "|" + "-" * indent + " "


def main():
    """Parse the command line, load the trace and print the requested table."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--json-trace",
                        type=str,
                        required=True,
                        help="json trace file output by "
                        "examples/offline_inference/profiling.py")
    parser.add_argument("--phase",
                        type=str,
                        required=True,
                        help="The phase to print the table for. This is either "
                        "prefill or decode_n, where n is the decode step "
                        "number")
    parser.add_argument("--table",
                        type=str,
                        choices=["summary", "model"],
                        default="summary",
                        help="Which table to print, the summary table or the "
                        "layerwise model table")

    args = parser.parse_args()

    # JSON text is UTF-8 by specification, so do not depend on the locale's
    # default encoding when reading the trace.
    with open(args.json_trace, encoding="utf-8") as f:
        profile_data = json.load(f)

    # Validate with an explicit check rather than `assert`, which is removed
    # when Python runs with -O and would leave a bare KeyError further down.
    if args.phase not in profile_data:
        available_phases = [
            x for x in profile_data if "prefill" in x or "decode" in x
        ]
        parser.error(f"Cannot find phase {args.phase} in profile data. "
                     f"Choose one among {available_phases}")

    phase_data = profile_data[args.phase]

    if args.table == "summary":
        entry_cls = SummaryStatsEntry
        stats = phase_data["summary_stats"]
        column_widths = dict(name=80,
                             cuda_time_us=12,
                             pct_cuda_time=12,
                             invocations=15)
    else:
        # argparse `choices` restricts --table to "summary" or "model".
        entry_cls = ModelStatsEntry
        stats = phase_data["model_stats"]
        column_widths = dict(name=60,
                             cpu_time_us=12,
                             cuda_time_us=12,
                             pct_cuda_time=12,
                             trace=60)

    entries_and_depths = flatten_entries(entry_cls, stats)

    # Indent entry names based on their depth in the tree.
    entries = []
    for entry, depth in entries_and_depths:
        entry.name = indent_string(entry.name,
                                   indent=depth,
                                   indent_style=_indent_prefix)
        entries.append(entry)

    # Pass the selected class directly instead of `type(entries[0])` so an
    # empty stats list does not raise IndexError before printing.
    TablePrinter(entry_cls, column_widths).print_table(entries)


if __name__ == "__main__":
    main()