[misc] CUDA Time Layerwise Profiler (#8337)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
parent 390be74649
commit 9d30a056e7
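This change adds a layerwise CUDA-time profiler to vLLM: examples/offline_profile.py runs one prefill step and a configurable number of decode steps under the new vllm.profiler.layerwise_profile context manager and records per-module and per-kernel CUDA/CPU times, while tools/profiler/print_layerwise_table.py and tools/profiler/visualize_layerwise_profile.py render the exported JSON as terminal tables or matplotlib stacked bar charts. A CI smoke test is added to the examples step. A minimal invocation, taken from the new script's --help text:

    python examples/offline_profile.py \
        --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \
        --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \
        --enforce-eager
    python tools/profiler/print_layerwise_table.py \
        --json-trace Llama31-8b-FP8.json --phase prefill --table summary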
@@ -184,6 +184,7 @@ steps:
     - python3 offline_inference_vision_language_multi_image.py
     - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
     - python3 offline_inference_encoder_decoder.py
+    - python3 offline_profile.py --model facebook/opt-125m
 
 - label: Prefix Caching Test # 9min
   #mirror_hardwares: [amd]
282 examples/offline_profile.py Normal file
@@ -0,0 +1,282 @@
import inspect
import json
import os
import sys
from argparse import RawTextHelpFormatter
from dataclasses import asdict, dataclass
from typing import Optional

import torch

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.profiler import layerwise_profile
from vllm.utils import FlexibleArgumentParser

BATCH_SIZE_DEFAULT = 1
PROMPT_LEN_DEFAULT = 256
OUTPUT_LEN_DEFAULT = 2


@dataclass
class ProfileContext:
    engine_args: EngineArgs
    prompt_len: int
    output_len: int
    batch_size: int
    save_chrome_traces_folder: Optional[str]


def get_dtype(dtype: str):
    if dtype == "torch.float":
        return torch.float
    else:
        return dtype


def run_profile(context: ProfileContext, csv_output: Optional[str],
                json_output: Optional[str]):
    print("Run profile with:")
    for key, value in asdict(context).items():
        print(f" {key} = {value}")

    # Create sampling params
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=args.output_len,
                                     ignore_eos=True)

    # Create LLM
    llm = LLM(**asdict(context.engine_args))
    batch_size = context.batch_size
    prompt_len = context.prompt_len
    output_len = context.output_len

    scheduler_config = llm.llm_engine.scheduler_config
    max_model_len = llm.llm_engine.model_config.max_model_len
    max_num_batched_tokens = scheduler_config.max_num_batched_tokens
    max_num_seqs = scheduler_config.max_num_seqs

    if batch_size * prompt_len > max_num_batched_tokens:
        print(f"ERROR: chosen batch_size * prompt_len "
              f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is "
              f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
              f"and therefore cannot be run in a single profile step, please "
              f"choose a smaller batch size or prompt length, or increase "
              f"--max-num-batched-tokens")
        sys.exit(-1)
    if batch_size >= max_num_seqs:
        print(
            f"ERROR: chosen batch_size ({batch_size}) is larger than "
            f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
            f"single profile step, please choose a smaller batch size")
        sys.exit(-1)
    print("llm.llm_engine.model_config.max_model_len: ",
          llm.llm_engine.model_config.max_model_len)
    if prompt_len + output_len > llm.llm_engine.model_config.max_model_len:
        print(
            f"ERROR: chosen prompt_len + output_len ({prompt_len} + "
            f"{output_len} = {prompt_len + output_len}) is larger than the "
            f"model's max_model_len ({max_model_len}), please choose a smaller "
            f"prompt_len or output_len, or increase --max-model-len")
        sys.exit(-1)

    def add_requests():
        for i in range(batch_size):
            prompt_token_ids = torch.randint(
                llm.llm_engine.model_config.get_vocab_size(),
                size=(prompt_len, )).tolist()

            llm.llm_engine.add_request(
                request_id=f"seq{i}",
                prompt={'prompt_token_ids': prompt_token_ids},
                params=sampling_params)

    def abort_requests():
        for i in range(batch_size):
            llm.llm_engine.abort_request(f"seq{i}")

    # Warm up run
    print("Warm up run ...")
    add_requests()
    llm.llm_engine.step()  # Prefill
    llm.llm_engine.step()  # Decode
    abort_requests()

    print("Profile run ...")
    add_requests()

    with layerwise_profile() as prefill_prof:
        llm.llm_engine.step()  # First step is prefill

    decode_profs = []
    for x in range(args.output_len - 1):
        with layerwise_profile() as decode_prof:
            llm.llm_engine.step()
        decode_profs.append(decode_prof)

    decode_results_list = [prof.results for prof in decode_profs]
    prefill_results = prefill_prof.results
    has_decode = len(decode_results_list) > 0

    LINE_WIDTH = 80
    print("=" * LINE_WIDTH)
    print(f"= Prefill Model Table "
          f"(prompt_len={prompt_len}, batch_size={batch_size})")
    print("=" * LINE_WIDTH)
    print()
    prefill_results.print_model_table()

    if has_decode:
        print()
        print("=" * LINE_WIDTH)
        print(f"= First Decode Step Model Table "
              f"(prompt_len={prompt_len}, batch_size={batch_size})")
        print("=" * LINE_WIDTH)
        print()
        decode_results_list[0].print_model_table()

    print()
    print("=" * LINE_WIDTH)
    print(f"= Prefill Summary Table "
          f"(prompt_len={prompt_len}, batch_size={batch_size})")
    print("=" * LINE_WIDTH)
    print()
    prefill_results.print_summary_table()

    if has_decode:
        print()
        print("=" * LINE_WIDTH)
        print(f"= First Decode Step Summary Table "
              f"(prompt_len={prompt_len}, batch_size={batch_size})")
        print("=" * LINE_WIDTH)
        print()
        decode_results_list[0].print_summary_table()

    if csv_output:
        csv_filename_base = csv_output.rstrip(".csv")
        prefill_results.export_model_stats_table_csv(
            csv_filename_base + "_prefill_model_table.csv")
        prefill_results.export_summary_stats_table_csv(
            csv_filename_base + "_prefill_summary_table.csv")

        if has_decode:
            decode_results_list[0].export_model_stats_table_csv(
                csv_filename_base + "_decode_model_table.csv")
            decode_results_list[0].export_summary_stats_table_csv(
                csv_filename_base + "_decode_summary_table.csv")

    if json_output:
        cuda_devices = [
            torch.cuda.get_device_properties(dev_idx)
            for dev_idx in range(torch.cuda.device_count())
        ]

        json_dict = {
            "context": {
                "python_version": f"{sys.version}",
                "torch_version": f"{torch.__version__}",
                "torch_cuda_version": f"{torch.version.cuda}",
                "cuda_devices": f"{cuda_devices}",
                **asdict(context)
            },
            "prefill": prefill_results.convert_stats_to_dict(),
        }

        if has_decode:
            for idx, dr in enumerate(decode_results_list):
                json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()

        with open(json_output.rstrip(".json") + ".json", "w+") as f:
            json.dump(json_dict, f, indent=2)

    if context.save_chrome_traces_folder is not None:
        os.makedirs(context.save_chrome_traces_folder, exist_ok=True)
        prefill_prof.profiler.export_chrome_trace(
            context.save_chrome_traces_folder + "/prefill.json")
        for idx, decode_prof in enumerate(decode_profs):
            decode_prof.profiler.export_chrome_trace(
                context.save_chrome_traces_folder + f"/decode_{idx + 1}.json")
        print("Traces saved as prefill.json and decode_1.json, etc."
              f" in folder {context.save_chrome_traces_folder}")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="""
Profile a model

    example:
    ```
    python examples/offline_profile.py \\
        --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\
        --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\
        --enforce-eager
    ```

    then you can use various tools to analyze the json output
    terminal ascii tables:
        ```
        python tools/profiler/print_layerwise_table.py \\
            --json-trace Llama31-8b-FP8.json --phase prefill --table summary
        ```
    or create matplotlib stacked bar charts:
        ```
        python tools/profiler/visualize_layerwise_profile.py \\
            --json-trace Llama31-8b-FP8.json \\
            --output-directory profile_breakdown --plot-metric pct_cuda_time
        ```
""",
                                    formatter_class=RawTextHelpFormatter)
    parser.add_argument(
        "--csv",
        type=str,
        default=None,
        help="Export the results as multiple csv file. This should be the root "
        "filename, will create <filename>_prefill_model_table.csv, "
        "<filename>_prefill_summary_table.csv, "
        "<filename>_decode_model_table.csv, and "
        "<filename>_decode_summary_table.csv")
    parser.add_argument(
        "--json",
        type=str,
        default=None,
        help="Export the results as a json file. This should be the filename")
    parser.add_argument("--save-chrome-traces-folder",
                        type=str,
                        help="Save chrome traces for the prefill and decode "
                        "will save traces as prefill.json and decode_1.json, "
                        "etc. inside this folder")
    parser.add_argument(
        "--prompt-len",
        type=int,
        default=PROMPT_LEN_DEFAULT,
        help=f"Length of the random prompt to use when profiling, all batched "
        f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}")
    parser.add_argument("--batch-size",
                        type=int,
                        default=BATCH_SIZE_DEFAULT,
                        help=f"Number of requests to run as a single batch, "
                        f"default={BATCH_SIZE_DEFAULT}")
    parser.add_argument(
        "--output-len",
        type=int,
        default=OUTPUT_LEN_DEFAULT,
        help=f"Number of llm steps to run (includes prefill and decode) "
        f"- default={OUTPUT_LEN_DEFAULT}")

    EngineArgs.add_cli_args(parser)

    args = parser.parse_args()

    context = ProfileContext(
        engine_args=EngineArgs.from_cli_args(args),
        **{
            k: v
            for k, v in vars(args).items()
            if k in inspect.signature(ProfileContext).parameters
        })
    run_profile(context, csv_output=args.csv, json_output=args.json)
77 tools/profiler/print_layerwise_table.py Normal file
@@ -0,0 +1,77 @@
import argparse
import json
from typing import Dict

from vllm.profiler.layerwise_profile import ModelStatsEntry, SummaryStatsEntry
from vllm.profiler.utils import TablePrinter, indent_string


def flatten_entries(entry_cls, profile_dict: Dict):
    entries_and_depth = []

    def get_entries(node, curr_depth=0):
        entries_and_depth.append((entry_cls(**node["entry"]), curr_depth))

        for child in node["children"]:
            get_entries(
                child,
                curr_depth=curr_depth + 1,
            )

    for root in profile_dict:
        get_entries(root)

    return entries_and_depth


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--json-trace",
                        type=str,
                        required=True,
                        help="json trace file output by "
                        "examples/offline_profile.py")
    parser.add_argument("--phase",
                        type=str,
                        choices=["prefill", "decode_1"],
                        required=True,
                        help="The phase to print the table for.")
    parser.add_argument("--table",
                        type=str,
                        choices=["summary", "model"],
                        default="summary",
                        help="Which table to print, the summary table or the "
                        "layerwise model table")

    args = parser.parse_args()

    with open(args.json_trace, "r") as f:
        profile_data = json.load(f)

    if args.table == "summary":
        entries_and_depths = flatten_entries(
            SummaryStatsEntry, profile_data[args.phase]["summary_stats"])
        column_widths = dict(name=80,
                             cuda_time_us=12,
                             pct_cuda_time=12,
                             invocations=15)
    elif args.table == "model":
        entries_and_depths = flatten_entries(
            ModelStatsEntry, profile_data[args.phase]["model_stats"])
        column_widths = dict(name=60,
                             cpu_time_us=12,
                             cuda_time_us=12,
                             pct_cuda_time=12,
                             trace=60)

    # indent entry names based on the depth
    entries = []
    for entry, depth in entries_and_depths:
        entry.name = indent_string(
            entry.name,
            indent=depth,
            indent_style=lambda indent: "|" + "-" * indent + " ")
        entries.append(entry)

    TablePrinter(type(entries[0]), column_widths).print_table(entries)
522 tools/profiler/visualize_layerwise_profile.py Normal file
@@ -0,0 +1,522 @@
import argparse
import copy
import json
import math
import os
from pathlib import Path
from typing import Any, List, Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd

## JSON parsing utils ####


def largest_dist_from_leaf(node: dict, depth: int = 0):
    if len(node["children"]) == 0:
        return depth
    return max([
        largest_dist_from_leaf(child, depth=depth + 1)
        for child in node["children"]
    ])


def get_entries_at_depth(depth: int,
                         entries_and_traces: List[Tuple[Any, Any]],
                         node: dict,
                         curr_depth: int = 0,
                         trace=()):
    # assert that the query is at kernel or module level
    assert depth == -1 or depth == -2

    if curr_depth == 0 and largest_dist_from_leaf(node) <= (abs(depth) - 1):
        # The tree is not tall enough!
        entries_and_traces.append((node["entry"], trace))
        return

    if largest_dist_from_leaf(node) == (abs(depth) - 1):
        entries_and_traces.append((node["entry"], trace))

    trace = (node["entry"]["name"], ) + trace
    for child in node["children"]:
        get_entries_at_depth(depth,
                             entries_and_traces,
                             child,
                             curr_depth=curr_depth + 1,
                             trace=trace)


def fold_nodes(root: dict, nodes_to_fold: List[str]):

    stack: List[dict] = [root]
    while len(stack) != 0:
        node = stack.pop()
        if node['entry']['name'] in nodes_to_fold:
            node["children"] = []
            continue
        for child in node["children"]:
            stack.append(child)
    return root


## Operation name cleanup utils ####


def trim_string_back(string: str, width: int) -> str:
    if len(string) > width:
        offset = len(string) - width + 3
        string = string[:-offset]
        if len(string) > 3:
            string = string + "..."
    return string


def shorten_plot_legend_strings(legend, max_char_len: int):
    for t in legend.get_texts():
        t.set_text(
            trim_string_back(abbreviate_known_names(t.get_text()),
                             max_char_len))


def abbreviate_known_names(name: str) -> str:
    abbreviations = {
        "MergedColumnParallelLinear": "MCPLinear",
        "QKVParallelLinear": "QKVPLinear",
        "RowParallelLinear": "RPLinear",
        "weight=": "w=",
        "bfloat16": "bf16",
        "float16": "f16",
    }
    for key, value in abbreviations.items():
        name = name.replace(key, value)
    return name


def attempt_to_make_names_unique(entries_and_traces):
    names, non_unique_names = (set(), set())

    def all_the_same(items) -> bool:
        return all(i == items[0] for i in items)

    for entry, _ in entries_and_traces:
        if entry["name"] in names:
            non_unique_names.add(entry["name"])
        else:
            names.add(entry["name"])

    for name in non_unique_names:
        entries_and_traces_with_name = [(entry, trace)
                                        for entry, trace in entries_and_traces
                                        if entry["name"] == name]

        zipped_traces = list(
            zip(*[trace for _, trace in entries_and_traces_with_name]))
        first_trace_difference = next(
            (i for i, trace_eles in enumerate(zipped_traces)
             if not all_the_same(trace_eles)), None)

        if first_trace_difference is None:
            # can't create a unique name, leave the names as they are;
            # they will get aggregated by the pivot_table call
            continue

        for entry, trace in entries_and_traces_with_name:
            entry["name"] = " <- ".join((entry["name"], ) +
                                        trace[:first_trace_difference + 1])


## Operation grouping utils ####
'''
Group operations in the given dataframe by some high-level ops like,
- gemms
- attention
- rms_norm
etc.
'''


def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:

    def is_rms_norm(op_name: str):
        if "rms_norm_kernel" in op_name:
            return True

    def is_attention_block(op_name: str):
        if "flash_fwd" in op_name or \
           "reshape_and_cache_flash_kernel" in op_name:
            return True

    def is_quant(op_name: str):
        if "scaled_fp8_quant" in op_name or \
           "scaled_int8_quant" in op_name:
            return True

    def is_gemm_op(op_name: str):
        if is_quant(op_name):
            return False
        if "xmma_gemm" in op_name or \
           "gemv2T_kernel" in op_name or \
           "splitKreduce" in op_name or \
           "void cutlass::Kernel" in op_name or \
           "void cutlass::device_kernel" in op_name or \
           "s16816gemm" in op_name:
            return True

    def is_elementwise_op(op_name: str):
        return "elementwise_kernel" in op_name

    def is_mem_op(op_name: str):
        return "memcpy" in op_name.lower() or \
               "memset" in op_name.lower()

    def is_vocab_embedding_op(op_name: str):
        return "vocabparallelembed" in op_name.lower()

    # nccl ops
    def is_nccl_op(op_name: str):
        return "nccl" in op_name.lower()

    def is_nccl_all_reduce(op_name: str):
        return is_nccl_op(op_name) and \
            ("all_reduce" in op_name.lower() or \
             "allreduce" in op_name.lower())

    def is_nccl_gather(op_name: str):
        return is_nccl_op(op_name) and \
            "gather" in op_name.lower()

    def is_nccl_broadcast(op_name: str):
        return is_nccl_op(op_name) and \
            "broadcast" in op_name.lower()

    # Reduce ops types
    def is_cross_device_reduce_1stage(op_name: str):
        return "cross_device_reduce_1stage" in op_name

    def is_cross_device_reduce_2stage(op_name: str):
        return "cross_device_reduce_2stage" in op_name

    def is_custom_ar_all_reduce_unreg(op_name: str):
        return "_C_custom_ar::all_reduce_unreg" in op_name

    def is_reduce_kernel(op_name: str):
        return "reduce_kernel" in op_name

    headers = list(trace_df)
    ops = copy.deepcopy(headers)

    attention_ops = list(filter(lambda x: is_attention_block(x), ops))
    ops = list(filter(lambda x: x not in attention_ops, ops))

    quant_ops = list(filter(lambda x: is_quant(x), ops))
    ops = list(filter(lambda x: x not in quant_ops, ops))

    gemm_ops = list(filter(lambda x: is_gemm_op(x), ops))
    ops = list(filter(lambda x: x not in gemm_ops, ops))

    rms_norm_ops = list(filter(lambda x: is_rms_norm(x), ops))
    ops = list(filter(lambda x: x not in rms_norm_ops, ops))

    vocab_embed_ops = list(filter(lambda x: is_vocab_embedding_op(x), ops))
    ops = list(filter(lambda x: x not in vocab_embed_ops, ops))

    mem_ops = list(filter(lambda x: is_mem_op(x), ops))
    ops = list(filter(lambda x: x not in mem_ops, ops))

    elementwise_ops = list(filter(lambda x: is_elementwise_op(x), ops))
    ops = list(filter(lambda x: x not in elementwise_ops, ops))

    nccl_all_reduce_ops = list(filter(lambda x: is_nccl_all_reduce(x), ops))
    ops = list(filter(lambda x: x not in nccl_all_reduce_ops, ops))

    nccl_gather_ops = list(filter(lambda x: is_nccl_gather(x), ops))
    ops = list(filter(lambda x: x not in nccl_gather_ops, ops))

    nccl_broadcast_ops = list(filter(lambda x: is_nccl_broadcast(x), ops))
    ops = list(filter(lambda x: x not in nccl_broadcast_ops, ops))

    nccl_other_ops = list(filter(lambda x: is_nccl_op(x), ops))
    ops = list(filter(lambda x: x not in nccl_other_ops, ops))

    cross_device_reduce_1stage_ops = list(
        filter(lambda x: is_cross_device_reduce_1stage(x), ops))
    ops = list(filter(lambda x: x not in cross_device_reduce_1stage_ops, ops))

    cross_device_reduce_2stage_ops = list(
        filter(lambda x: is_cross_device_reduce_2stage(x), ops))
    ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops))

    custom_ar_all_reduce_unreg_ops = list(
        filter(lambda x: is_custom_ar_all_reduce_unreg(x), ops))
    ops = list(filter(lambda x: x not in custom_ar_all_reduce_unreg_ops, ops))

    reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops))
    ops = list(filter(lambda x: x not in reduce_kernel_ops, ops))

    if len(attention_ops):
        trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1)
    if len(quant_ops):
        trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1)
    if len(gemm_ops):
        trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1)
    if len(rms_norm_ops):
        trace_df['rms_norm_ops'] = trace_df[rms_norm_ops].agg("sum", axis=1)
    if len(vocab_embed_ops):
        trace_df['vocab_embed_ops'] = trace_df[vocab_embed_ops].agg("sum",
                                                                    axis=1)
    if len(mem_ops):
        trace_df['mem_ops'] = trace_df[mem_ops].agg("sum", axis=1)
    if len(elementwise_ops):
        trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum",
                                                                    axis=1)

    if len(nccl_all_reduce_ops):
        trace_df['nccl_all_reduce_ops'] = trace_df[nccl_all_reduce_ops].agg(
            "sum", axis=1)
    if len(nccl_gather_ops):
        trace_df['nccl_gather_ops'] = trace_df[nccl_gather_ops].agg("sum",
                                                                    axis=1)
    if len(nccl_broadcast_ops):
        trace_df['nccl_broadcast_ops'] = trace_df[nccl_broadcast_ops].agg(
            "sum", axis=1)
    if len(nccl_other_ops):
        trace_df['nccl_other_ops'] = trace_df[nccl_other_ops].agg("sum",
                                                                  axis=1)

    if len(cross_device_reduce_1stage_ops):
        trace_df['cross_device_reduce_1stage_ops'] = trace_df[
            cross_device_reduce_1stage_ops].agg("sum", axis=1)
    if len(cross_device_reduce_2stage_ops):
        trace_df['cross_device_reduce_2stage_ops'] = trace_df[
            cross_device_reduce_2stage_ops].agg("sum", axis=1)
    if len(custom_ar_all_reduce_unreg_ops):
        trace_df['custom_ar_all_reduce_unreg_ops'] = trace_df[
            custom_ar_all_reduce_unreg_ops].agg("sum", axis=1)
    if len(reduce_kernel_ops):
        trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum",
                                                                        axis=1)

    trace_df.drop(
        attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops +
        mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops +
        nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops +
        cross_device_reduce_2stage_ops + custom_ar_all_reduce_unreg_ops +
        reduce_kernel_ops,
        axis=1,
        inplace=True)
    return trace_df


## Data plotting utils ####


def plot_trace_df(traces_df: pd.DataFrame,
                  plot_metric: str,
                  plot_title: str,
                  output: Optional[Path] = None):

    phases = traces_df['phase'].unique()
    traces_df = traces_df.pivot_table(index="phase",
                                      columns="name",
                                      values=plot_metric,
                                      aggfunc="sum")

    traces_df = group_trace_by_operations(traces_df)

    # Make the figure
    fig, ax = plt.subplots(1, figsize=(5, 8), sharex=True)

    # Draw the stacked bars
    ops = list(traces_df)
    bottom = [0] * len(phases)
    for op in ops:
        values = [traces_df[op][phase] for phase in phases]
        values = list(map(lambda x: 0.0 if math.isnan(x) else x, values))
        ax.bar(phases, values, label=op, bottom=bottom)
        bottom = [bottom[j] + values[j] for j in range(len(phases))]

    # Write the values as text on the bars
    for bar in ax.patches:
        if bar.get_height() != 0:
            ax.text(bar.get_x() + bar.get_width() / 2,
                    bar.get_height() / 2 + bar.get_y(),
                    f"{round(bar.get_height(), 2)}",
                    ha='center',
                    color='w',
                    weight='bold',
                    size=5)

    # Setup legend
    handles, labels = plt.gca().get_legend_handles_labels()
    legend = fig.legend(handles,
                        labels,
                        loc='center left',
                        bbox_to_anchor=(1, 1))
    shorten_plot_legend_strings(legend, 50)

    # Setup labels and title
    plt.setp(ax.get_xticklabels(), rotation=90)
    ax.set_ylabel(plot_metric)
    plt.suptitle(plot_title)

    plt.savefig(output, bbox_inches='tight')
    print("Created: ", output)


def main(
        json_trace: Path,
        output_directory: Path,
        depth: int,  # Fetch/Plot operations at this depth of the Json tree
        plot_metric: str,
        make_names_unique: bool,
        top_k: int,
        json_nodes_to_fold: List[str]):

    def prepare_data(profile_json: dict, step_keys: List[str]) -> pd.DataFrame:

        def get_entries_and_traces(key: str):
            entries_and_traces: List[Tuple[Any, Any]] = []
            for root in profile_json[key]["summary_stats"]:
                # Fold nodes in the traces as per user request. i.e. simply
                # make the requested nodes leaf-nodes.
                root = fold_nodes(root, json_nodes_to_fold)
                get_entries_at_depth(depth, entries_and_traces, root)
            return entries_and_traces

        def keep_only_top_entries(df: pd.DataFrame,
                                  metric: str,
                                  top_k: int = 9) -> pd.DataFrame:
            df.loc[df.nsmallest(len(df) - top_k + 1, metric).index,
                   ["name"]] = "others"
            return df

        # Get data for each key
        traces = list(map(lambda x: get_entries_and_traces(x), step_keys))

        # Attempt some cleanup
        if make_names_unique:
            for trace in traces:
                attempt_to_make_names_unique(trace)

        # To pandas dataframe
        trace_dfs = list(
            map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0),
                traces))

        # Respect top_k
        if top_k:
            trace_dfs = list(
                map(
                    lambda trace_df: keep_only_top_entries(
                        trace_df, "cuda_time_us", top_k), trace_dfs))

        # Fill in information about the step-keys
        for trace_df, step_key in zip(trace_dfs, step_keys):
            trace_df['phase'] = step_key

        # Combine all data frames so they can be put in a single plot
        traces_df = pd.concat(trace_dfs)

        # Add a derived metric `cuda_time_ms`
        traces_df["cuda_time_ms"] = traces_df["cuda_time_us"] / 1000
        traces_df = traces_df.fillna(0)

        return traces_df

    def make_plot_title_suffix(profile_json: dict) -> str:
        context = profile_json["context"]
        sparsity = context.get('sparsity', None)
        return (f"{context['model']}\n"
                f"Batch={context['batch_size']}, "
                f"PromptLen={context['prompt_len']}, "
                f"OutputLen={context['output_len']},"
                f"NumGpus={context['tensor_parallel_size']}"
                f"{', Sparsity ' + sparsity if sparsity else ''}")

    profile_json = None
    with open(json_trace, "r") as f:
        profile_json = json.load(f)
    assert profile_json is not None

    # Get all `llm.generate.step()` profile
    step_traces = list(profile_json.keys())
    assert (step_traces[0] == 'context')
    step_traces = step_traces[1:]  # have only prefill and decodes
    prefills = list(filter(lambda x: "prefill" in x, step_traces))
    all_decodes = list(filter(lambda x: "decode" in x, step_traces))
    assert len(prefills) + len(all_decodes) == len(step_traces)
    assert len(prefills) == 1

    decodes = all_decodes[::args.step_plot_interval]
    if decodes[-1] != all_decodes[-1]:
        # Always have the last decode
        decodes.append(all_decodes[-1])

    prefill_traces = prepare_data(profile_json, prefills)
    decode_traces = prepare_data(profile_json, decodes)

    plot_title_suffix = make_plot_title_suffix(profile_json)

    plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix,
                  output_directory / Path("prefill.png"))
    plot_trace_df(decode_traces, plot_metric, "decodes " + plot_title_suffix,
                  output_directory / Path("decode_steps.png"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--json-trace",
        type=str,
        required=True,
        help="json trace file output by examples/offline_profile.py")
    parser.add_argument("--output-directory",
                        type=str,
                        required=False,
                        help="Directory to output plots")
    parser.add_argument("--level",
                        type=str,
                        default="module",
                        choices=["module", "kernel"])
    parser.add_argument("--top-k",
                        type=int,
                        default=12,
                        help="Only graph the top `top_k` entries by time.")
    parser.add_argument("--fold-json-node",
                        nargs='+',
                        default=['Sampler', 'LogitsProcessor'],
                        help='Do not plot the children of these nodes. Let, \
                            the node represent the aggregate of all its \
                            children')
    parser.add_argument("--plot-metric",
                        type=str,
                        default="cuda_time_ms",
                        help='Metric to plot. some options are cuda_time_ms, \
                            pct_cuda_time')
    parser.add_argument(
        "--step-plot-interval",
        type=int,
        default=4,
        help="For every `step_plot_interval` steps, plot 1 step")

    args = parser.parse_args()

    # Prepare/Extract relevant args
    make_names_unique = False
    if args.level == "module":
        depth = -2
        make_names_unique = True
    elif args.level == "kernel":
        depth = -1
    else:
        raise Exception(f"Unexpected level value ({args.level})")

    output_directory = args.output_directory if args.output_directory else Path(
        args.json_trace).parent

    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    main(Path(args.json_trace), output_directory, depth, args.plot_metric,
         make_names_unique, args.top_k, args.fold_json_node)
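For orientation, the JSON trace consumed by the two tools above has roughly the following layout, inferred from the json_dict assembled in examples/offline_profile.py and from LayerwiseProfileResults.convert_stats_to_dict() below. The literal is an illustrative sketch, not the output of any script, and the field values are placeholders:

    # Approximate layout of the exported JSON trace (illustrative sketch only).
    trace = {
        "context": {"batch_size": 1, "prompt_len": 256, "output_len": 2},  # plus python/torch versions and engine args
        "prefill": {
            "summary_stats": [              # tree of per-module CUDA-time rollups
                {"entry": {"name": "...", "cuda_time_us": 0.0,
                           "pct_cuda_time": 0.0, "invocations": 1},
                 "children": []},
            ],
            "model_stats": [],              # same tree shape; adds cpu_time_us and a kernel trace string
        },
        "decode_1": {},                     # one key per profiled decode step
    }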
5 vllm/profiler/__init__.py Normal file
@@ -0,0 +1,5 @@
from .layerwise_profile import layerwise_profile

__all__ = [
    "layerwise_profile",
]
354 vllm/profiler/layerwise_profile.py Normal file
@@ -0,0 +1,354 @@
import copy
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, TypeAlias, Union

import pandas as pd
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
from torch._C._profiler import _EventType, _ExperimentalConfig, _ProfilerEvent
from torch.autograd.profiler import FunctionEvent
from torch.profiler import ProfilerActivity, profile

from vllm.profiler.utils import (TablePrinter, event_has_module,
                                 event_is_torch_op, event_module_repr,
                                 event_torch_op_stack_trace, indent_string)


@dataclass
class _ModuleTreeNode:
    event: _ProfilerEvent
    parent: Optional['_ModuleTreeNode'] = None
    children: List['_ModuleTreeNode'] = field(default_factory=list)
    trace: str = ""

    @property
    def is_leaf(self):
        return (self.event.children is None or len(self.event.children) == 0)

    @property
    def is_torch_op(self):
        return event_is_torch_op(self.event)

    @property
    def is_cuda(self):
        return (self.event.tag == _EventType.Kineto
                and self.event.typed[1].device_type == DeviceType.CUDA)


@dataclass
class SummaryStatsEntry:
    name: str
    cuda_time_us: float
    pct_cuda_time: float
    invocations: int


@dataclass
class ModelStatsEntry:
    name: str
    cpu_time_us: float
    cuda_time_us: float
    pct_cuda_time: float
    trace: str


StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry]


@dataclass
class _StatsTreeNode:
    entry: StatsEntry
    children: List[StatsEntry]
    parent: Optional[StatsEntry]


@dataclass
class LayerwiseProfileResults(profile):
    _kineto_results: _ProfilerResult
    _kineto_event_correlation_map: Dict[int,
                                        List[_KinetoEvent]] = field(init=False)
    _event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False)
    _module_tree: List[_ModuleTreeNode] = field(init=False)
    _model_stats_tree: List[_StatsTreeNode] = field(init=False)
    _summary_stats_tree: List[_StatsTreeNode] = field(init=False)

    def __post_init__(self):
        self._build_correlation_map()
        self._build_module_tree()
        self._build_stats_trees()

    def print_model_table(self, column_widths: Dict[str, int] = None):
        _column_widths = dict(name=60,
                              cpu_time_us=12,
                              cuda_time_us=12,
                              pct_cuda_time=12,
                              trace=60)
        if column_widths:
            _column_widths.update(**column_widths)
        filtered_model_table = [
            (depth, row)
            for depth, row in self._flatten_stats_tree(self._model_stats_tree)
            if row.cuda_time_us > 0 or row.cpu_time_us > 0
        ]
        TablePrinter(ModelStatsEntry, _column_widths).print_table(
            self._indent_row_names_based_on_depth(
                filtered_model_table,
                indent_style=lambda indent: "|" + "-" * indent + " "))

    def print_summary_table(self, column_widths: Dict[str, int] = None):
        _column_widths = dict(name=80,
                              cuda_time_us=12,
                              pct_cuda_time=12,
                              invocations=15)
        if column_widths:
            _column_widths.update(**column_widths)
        filtered_summary_table = [(depth, row)
                                  for depth, row in self._flatten_stats_tree(
                                      self._summary_stats_tree)
                                  if row.cuda_time_us > 0]
        TablePrinter(SummaryStatsEntry, _column_widths).print_table(
            self._indent_row_names_based_on_depth(
                filtered_summary_table,
                indent_style=lambda indent: "|" + "-" * indent + " "))

    def export_model_stats_table_csv(self, filename: str):
        df = pd.DataFrame([
            asdict(row)
            for _, row in self._flatten_stats_tree(self._model_stats_tree)
        ])
        df.to_csv(filename)

    def export_summary_stats_table_csv(self, filename: str):
        df = pd.DataFrame([
            asdict(row)
            for _, row in self._flatten_stats_tree(self._summary_stats_tree)
        ])
        df.to_csv(filename)

    def convert_stats_to_dict(self) -> Dict:
        return {
            "summary_stats":
            self._convert_stats_tree_to_dict(self._summary_stats_tree),
            "model_stats":
            self._convert_stats_tree_to_dict(self._model_stats_tree)
        }

    @staticmethod
    def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int,
                                                                 StatsEntry]],
                                         indent_style: Union[Callable[[int],
                                                                      str],
                                                             str] = " "):
        indented_rows = []
        for depth, row in depths_rows:
            if row.cuda_time_us == 0:
                continue
            indented_row = copy.deepcopy(row)
            indented_row.name = indent_string(indented_row.name, depth,
                                              indent_style)
            indented_rows.append(indented_row)
        return indented_rows

    def _build_correlation_map(self):
        self._kineto_event_correlation_map = defaultdict(list)
        for event in self._kineto_results.events():
            self._kineto_event_correlation_map[event.correlation_id()].append(
                event)

    def _build_module_tree(self):
        self._module_tree = []
        event_tree = self._kineto_results.experimental_event_tree()

        def _df_traversal(event: _ProfilerEvent,
                          curr_node: Optional[_ModuleTreeNode] = None):

            # For the tensor parallel case for now only look at task 1
            if event.start_tid != 1:
                return

            if event_has_module(event):
                node = _ModuleTreeNode(event=event, parent=curr_node)
                if curr_node:
                    curr_node.children.append(node)
                else:
                    self._module_tree.append(node)
                curr_node = node

            is_leaf = (event.children is None or len(event.children) == 0)
            if is_leaf and curr_node:
                node = _ModuleTreeNode(
                    event=event,
                    parent=curr_node,
                    trace=event_torch_op_stack_trace(
                        event, until=lambda x: event_has_module(x)))
                curr_node.children.append(node)
                curr_node = node

            for child in event.children:
                _df_traversal(child, curr_node)

        for root in event_tree:
            _df_traversal(root)

    def _get_kineto_gpu_event(self, node: _ModuleTreeNode):
        if node.event.tag != _EventType.Kineto:
            return None
        correlated_kineto_events = self._kineto_event_correlation_map.get(
            node.event.correlation_id, [])
        iterator = (x for x in correlated_kineto_events
                    if x.device_type() == DeviceType.CUDA
                    and x.name() == node.event.name)
        return next(iterator, None)

    def _cumulative_cuda_time(self, node: _ModuleTreeNode):
        'Return cuda time in microseconds'

        def _cumulative_cuda_time_recursive(node: _ModuleTreeNode):
            if node.is_leaf and (gpu_kineto_event :=
                                 self._get_kineto_gpu_event(node)):
                return gpu_kineto_event.duration_ns() / 1000.0
            else:
                cumulative_cuda_time = 0
                for child in node.children:
                    cumulative_cuda_time += _cumulative_cuda_time_recursive(
                        child)
                return cumulative_cuda_time

        return _cumulative_cuda_time_recursive(node)

    def _total_cuda_time(self):
        return sum(
            [self._cumulative_cuda_time(root) for root in self._module_tree])

    def _build_stats_trees(self):
        summary_dict: Dict[str, self.StatsTreeNode] = {}
        total_cuda_time = self._total_cuda_time()

        def pct_cuda_time(cuda_time_us):
            return (cuda_time_us / total_cuda_time) * 100

        def build_summary_stats_tree_df(
                node: _ModuleTreeNode,
                parent: Optional[_StatsTreeNode] = None,
                summary_trace: Tuple[str] = ()):

            if event_has_module(node.event):
                name = event_module_repr(node.event)
                cuda_time_us = self._cumulative_cuda_time(node)
            elif (gpu_kineto_event := self._get_kineto_gpu_event(node)):
                name = gpu_kineto_event.name()
                cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0
            else:
                return None

            summary_trace = summary_trace + (name, )
            if summary_trace in summary_dict:
                entry = summary_dict[summary_trace].entry
                entry.cuda_time_us += cuda_time_us
                entry.invocations += 1
                entry.pct_cuda_time = pct_cuda_time(entry.cuda_time_us)
            else:
                new_node = _StatsTreeNode(entry=SummaryStatsEntry(
                    name=name,
                    cuda_time_us=cuda_time_us,
                    pct_cuda_time=pct_cuda_time(cuda_time_us),
                    invocations=1),
                                          children=[],
                                          parent=parent)
                if parent:
                    parent.children.append(new_node)
                summary_dict[summary_trace] = new_node

            for child in node.children:
                build_summary_stats_tree_df(child, summary_dict[summary_trace],
                                            summary_trace)

            return summary_dict[summary_trace]

        self._summary_stats_tree = []
        for root in self._module_tree:
            self._summary_stats_tree.append(build_summary_stats_tree_df(root))

        def build_model_stats_tree_df(node: _ModuleTreeNode,
                                      parent: Optional[_StatsTreeNode] = None):
            if event_has_module(node.event, ):
                name = event_module_repr(node.event)
                cuda_time_us = self._cumulative_cuda_time(node)
                cpu_time_us = node.event.duration_time_ns / 1000
                trace = ""
            elif (gpu_kineto_event := self._get_kineto_gpu_event(node)):
                name = gpu_kineto_event.name()
                cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0
                cpu_time_us = 0
                trace = node.trace
            else:
                return None

            new_node = _StatsTreeNode(entry=ModelStatsEntry(
                name=name,
                cpu_time_us=cpu_time_us,
                cuda_time_us=cuda_time_us,
                pct_cuda_time=pct_cuda_time(cuda_time_us),
                trace=trace),
                                      parent=parent,
                                      children=[])
            if parent:
                parent.children.append(new_node)

            for child in node.children:
                build_model_stats_tree_df(child, new_node)

            return new_node

        self._model_stats_tree = []
        for root in self._module_tree:
            self._model_stats_tree.append(build_model_stats_tree_df(root))

    def _flatten_stats_tree(
            self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]:
        entries: List[Tuple[int, StatsEntry]] = []

        def df_traversal(node: _StatsTreeNode, depth=0):
            entries.append((depth, node.entry))
            for child in node.children:
                df_traversal(child, depth=depth + 1)

        for root in tree:
            df_traversal(root)

        return entries

    def _convert_stats_tree_to_dict(self,
                                    tree: List[_StatsTreeNode]) -> List[Dict]:
        root_dicts: List[Dict] = []

        def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]):
            curr_json_list.append({
                "entry": asdict(node.entry),
                "children": []
            })
            for child in node.children:
                df_traversal(child, curr_json_list[-1]["children"])

        for root in tree:
            df_traversal(root, root_dicts)

        return root_dicts


class layerwise_profile(profile):

    def __init__(self):
        super().__init__(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
            with_stack=True,
            with_modules=True,
            experimental_config=_ExperimentalConfig(verbose=True))

    def __enter__(self):
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)
        self.results = LayerwiseProfileResults(self.profiler.kineto_results)
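A condensed sketch of how this new module is driven, mirroring the calls made in examples/offline_profile.py above (assumes a CUDA-capable machine; the model name and token ids are illustrative placeholders):

    from vllm import LLM, SamplingParams
    from vllm.profiler import layerwise_profile

    llm = LLM(model="facebook/opt-125m", enforce_eager=True)
    llm.llm_engine.add_request(request_id="seq0",
                               prompt={'prompt_token_ids': [0] * 256},
                               params=SamplingParams(max_tokens=2,
                                                     ignore_eos=True))

    with layerwise_profile() as prefill_prof:
        llm.llm_engine.step()                       # first step runs the prefill

    results = prefill_prof.results                  # a LayerwiseProfileResults
    results.print_summary_table()                   # per-module CUDA-time rollup
    results.export_model_stats_table_csv("prefill_model_table.csv")
    prefill_json = results.convert_stats_to_dict()  # same trees, JSON-serializable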
145 vllm/profiler/utils.py Normal file
@@ -0,0 +1,145 @@
import dataclasses
from typing import Callable, Dict, List, Type, Union

from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata

#
# String / Print Manipulation
#


def trim_string_front(string, width):
    if len(string) > width:
        offset = len(string) - width + 3
        string = string[offset:]
        if len(string) > 3:
            string = "..." + string[3:]
    return string


def trim_string_back(string, width):
    if len(string) > width:
        offset = len(string) - width + 3
        string = string[:-offset]
        if len(string) > 3:
            string = string + "..."
    return string


class TablePrinter:

    def __init__(self, row_cls: Type[dataclasses.dataclass],
                 column_widths: Dict[str, int]):
        self.row_cls = row_cls
        self.fieldnames = [x.name for x in dataclasses.fields(row_cls)]
        self.column_widths = column_widths
        assert set(self.column_widths.keys()) == set(self.fieldnames)

    def print_table(self, rows: List[dataclasses.dataclass]):
        self._print_header()
        self._print_line()
        for row in rows:
            self._print_row(row)

    def _print_header(self):
        for i, f in enumerate(self.fieldnames):
            last = (i == len(self.fieldnames) - 1)
            col_width = self.column_widths[f]
            print(trim_string_back(f, col_width).ljust(col_width),
                  end=" | " if not last else "\n")

    def _print_row(self, row):
        assert isinstance(row, self.row_cls)

        for i, f in enumerate(self.fieldnames):
            last = (i == len(self.fieldnames) - 1)
            col_width = self.column_widths[f]
            val = getattr(row, f)

            val_str = ""
            if isinstance(val, str):
                val_str = trim_string_back(val, col_width).ljust(col_width)
            elif type(val) in [float, int]:
                val_str = f"{float(val):>.2f}".rjust(col_width)
            else:
                val_str = f"{val}".rjust(col_width)
            print(val_str, end=" | " if not last else "\n")

    def _print_line(self):
        total_col_width = 0
        for column_width in self.column_widths.values():
            total_col_width += column_width
        print("=" * (total_col_width + 3 * (len(self.column_widths) - 1)))


def indent_string(string: str,
                  indent: int,
                  indent_style: Union[Callable[[int], str], str] = " ") -> str:
    if indent:
        if isinstance(indent_style, str):
            return indent_style * indent + string
        else:
            return indent_style(indent) + string
    else:
        return string


#
# _ProfilerEvent utils
#


def event_has_module(event: _ProfilerEvent) -> bool:
    event_type, typed_event = event.typed
    if event_type == _EventType.PyCall:
        return typed_event.module is not None
    return False


def event_is_torch_op(event: _ProfilerEvent) -> bool:
    return event.tag == _EventType.TorchOp


def event_arg_repr(arg) -> str:
    if arg is None or type(arg) in [float, int, bool, str]:
        return f"{arg}"
    elif isinstance(arg, list):
        return f"[{', '.join([event_arg_repr(x) for x in arg])}]"
    elif isinstance(arg, tuple):
        return f"({', '.join([event_arg_repr(x) for x in arg])})"
    else:
        assert isinstance(arg,
                          _TensorMetadata), f"Unsupported type: {type(arg)}"
        sizes_str = ', '.join([str(x) for x in arg.sizes])
        return f"{str(arg.dtype).replace('torch.', '')}[{sizes_str}]"


def event_torch_op_repr(event: _ProfilerEvent) -> str:
    assert event.tag == _EventType.TorchOp
    args_str = ', '.join([event_arg_repr(x) for x in event.typed[1].inputs])
    return f"{event.name}({args_str})".replace("aten::", "")


def event_module_repr(event: _ProfilerEvent) -> str:
    assert event_has_module(event)
    module = event.typed[1].module
    if module.parameters and len(module.parameters) > 0:
        args_str = ', '.join(
            [f'{x[0]}={event_arg_repr(x[1])}' for x in module.parameters])
        return f"{module.cls_name}({args_str})"
    else:
        return module.cls_name


def event_torch_op_stack_trace(curr_event: _ProfilerEvent,
                               until: Callable[[_ProfilerEvent], bool]) -> str:
    trace = ""
    curr_event = curr_event.parent
    while curr_event and not until(curr_event):
        if event_is_torch_op(curr_event):
            if len(trace) > 0:
                trace += " <- "
            trace += event_torch_op_repr(curr_event)
        curr_event = curr_event.parent

    return trace
@@ -1742,10 +1742,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         return [output]
 
 
-class CUDAGraphRunner:
+# NOTE: this is nn.Module so the profiler can properly capture/group
+# kernels calls made within the graph
+class CUDAGraphRunner(nn.Module):
 
     def __init__(self, model: nn.Module, backend_name: str,
                  attn_state: AttentionState, is_encoder_decoder_model: bool):
+        super().__init__()
         self.model = model
         self.backend_name = backend_name
         self.attn_state = attn_state
@@ -1892,9 +1895,6 @@ class CUDAGraphRunner:
 
         return self.output_buffers
 
-    def __call__(self, *args, **kwargs):
-        return self.forward(*args, **kwargs)
-
 
 def _get_graph_batch_size(batch_size: int) -> int:
     """Returns the padded batch size given actual batch size.