[misc] CUDA Time Layerwise Profiler (#8337)

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson 2024-10-17 10:36:09 -04:00 committed by GitHub
parent 390be74649
commit 9d30a056e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 1390 additions and 4 deletions

View File

@ -184,6 +184,7 @@ steps:
- python3 offline_inference_vision_language_multi_image.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference_encoder_decoder.py
- python3 offline_profile.py --model facebook/opt-125m
- label: Prefix Caching Test # 9min
#mirror_hardwares: [amd]

282
examples/offline_profile.py Normal file
View File

@ -0,0 +1,282 @@
import inspect
import json
import os
import sys
from argparse import RawTextHelpFormatter
from dataclasses import asdict, dataclass
from typing import Optional
import torch
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.profiler import layerwise_profile
from vllm.utils import FlexibleArgumentParser
BATCH_SIZE_DEFAULT = 1
PROMPT_LEN_DEFAULT = 256
OUTPUT_LEN_DEFAULT = 2
@dataclass
class ProfileContext:
engine_args: EngineArgs
prompt_len: int
output_len: int
batch_size: int
save_chrome_traces_folder: Optional[str]
def get_dtype(dtype: str):
if dtype == "torch.float":
return torch.float
else:
return dtype
def run_profile(context: ProfileContext, csv_output: Optional[str],
json_output: Optional[str]):
print("Run profile with:")
for key, value in asdict(context).items():
print(f" {key} = {value}")
# Create sampling params
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=args.output_len,
ignore_eos=True)
# Create LLM
llm = LLM(**asdict(context.engine_args))
batch_size = context.batch_size
prompt_len = context.prompt_len
output_len = context.output_len
scheduler_config = llm.llm_engine.scheduler_config
max_model_len = llm.llm_engine.model_config.max_model_len
max_num_batched_tokens = scheduler_config.max_num_batched_tokens
max_num_seqs = scheduler_config.max_num_seqs
if batch_size * prompt_len > max_num_batched_tokens:
print(f"ERROR: chosen batch_size * prompt_len "
f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is "
f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
f"and therefore cannot be run in a single profile step, please "
f"choose a smaller batch size or prompt length, or increase "
f"--max-num-batched-tokens")
sys.exit(-1)
if batch_size >= max_num_seqs:
print(
f"ERROR: chosen batch_size ({batch_size}) is larger than "
f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
f"single profile step, please choose a smaller batch size")
sys.exit(-1)
print("llm.llm_engine.model_config.max_model_len: ",
llm.llm_engine.model_config.max_model_len)
if prompt_len + output_len > llm.llm_engine.model_config.max_model_len:
print(
f"ERROR: chosen prompt_len + output_len ({prompt_len} + "
f"{output_len} = {prompt_len + output_len}) is larger than the "
f"model's max_model_len ({max_model_len}), please choose a smaller "
f"prompt_len or output_len, or increase --max-model-len")
sys.exit(-1)
def add_requests():
for i in range(batch_size):
prompt_token_ids = torch.randint(
llm.llm_engine.model_config.get_vocab_size(),
size=(prompt_len, )).tolist()
llm.llm_engine.add_request(
request_id=f"seq{i}",
prompt={'prompt_token_ids': prompt_token_ids},
params=sampling_params)
def abort_requests():
for i in range(batch_size):
llm.llm_engine.abort_request(f"seq{i}")
# Warm up run
print("Warm up run ...")
add_requests()
llm.llm_engine.step() # Prefill
llm.llm_engine.step() # Decode
abort_requests()
print("Profile run ...")
add_requests()
with layerwise_profile() as prefill_prof:
llm.llm_engine.step() # First step is prefill
decode_profs = []
for x in range(args.output_len - 1):
with layerwise_profile() as decode_prof:
llm.llm_engine.step()
decode_profs.append(decode_prof)
decode_results_list = [prof.results for prof in decode_profs]
prefill_results = prefill_prof.results
has_decode = len(decode_results_list) > 0
LINE_WIDTH = 80
print("=" * LINE_WIDTH)
print(f"= Prefill Model Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * LINE_WIDTH)
print()
prefill_results.print_model_table()
if has_decode:
print()
print("=" * LINE_WIDTH)
print(f"= First Decode Step Model Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * LINE_WIDTH)
print()
decode_results_list[0].print_model_table()
print()
print("=" * LINE_WIDTH)
print(f"= Prefill Summary Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * LINE_WIDTH)
print()
prefill_results.print_summary_table()
if has_decode:
print()
print("=" * LINE_WIDTH)
print(f"= First Decode Step Summary Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * LINE_WIDTH)
print()
decode_results_list[0].print_summary_table()
if csv_output:
csv_filename_base = csv_output.rstrip(".csv")
prefill_results.export_model_stats_table_csv(
csv_filename_base + "_prefill_model_table.csv")
prefill_results.export_summary_stats_table_csv(
csv_filename_base + "_prefill_summary_table.csv")
if has_decode:
decode_results_list[0].export_model_stats_table_csv(\
csv_filename_base + "_decode_model_table.csv")
decode_results_list[0].export_summary_stats_table_csv(
csv_filename_base + "_decode_summary_table.csv")
if json_output:
cuda_devices = [
torch.cuda.get_device_properties(dev_idx)
for dev_idx in range(torch.cuda.device_count())
]
json_dict = {
"context": {
"python_version": f"{sys.version}",
"torch_version": f"{torch.__version__}",
"torch_cuda_version": f"{torch.version.cuda}",
"cuda_devices": f"{cuda_devices}",
**asdict(context)
},
"prefill": prefill_results.convert_stats_to_dict(),
}
if has_decode:
for idx, dr in enumerate(decode_results_list):
json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()
for idx, dr in enumerate(decode_results_list[1:]):
json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()
with open(json_output.rstrip(".json") + ".json", "w+") as f:
json.dump(json_dict, f, indent=2)
pass
if context.save_chrome_traces_folder is not None:
os.makedirs(context.save_chrome_traces_folder, exist_ok=True)
prefill_prof.profiler.export_chrome_trace(
context.save_chrome_traces_folder + "/prefill.json")
for idx, decode_prof in enumerate(decode_profs):
decode_prof.profiler.export_chrome_trace(
context.save_chrome_traces_folder + f"/decode_{idx + 1}.json")
print("Traces saved as prefill.json and decode_1.json, etc."
f" in folder {context.save_chrome_traces_folder}")
if __name__ == "__main__":
parser = FlexibleArgumentParser(description="""
Profile a model
example:
```
python examples/offline_profile.py \\
--model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\
--prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\
--enforce-eager
```
then you can use various tools to analyze the json output
terminal ascii tables:
```
python tools/profiler/print_layerwise_table.py \\
--json-trace Llama31-8b-FP8.json --phase prefill --table summary
```
or create matplotlib stacked bar charts:
```
python tools/profiler/visualize_layerwise_profile.py \\
--json-trace Llama31-8b-FP8.json \\
--output-directory profile_breakdown --plot-metric pct_cuda_time
```
""",
formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--csv",
type=str,
default=None,
help="Export the results as multiple csv file. This should be the root "
"filename, will create <filename>_prefill_model_table.csv, "
"<filename>_prefill_summary_table.csv, "
"<filename>_decode_model_table.csv, and "
"<filename>_decode_summary_table.csv")
parser.add_argument(
"--json",
type=str,
default=None,
help="Export the results as a json file. This should be the filename")
parser.add_argument("--save-chrome-traces-folder",
type=str,
help="Save chrome traces for the prefill and decode "
"will save traces as prefill.json and decode_1.json, "
"etc. inside this folder")
parser.add_argument(
"--prompt-len",
type=int,
default=PROMPT_LEN_DEFAULT,
help=f"Length of the random prompt to use when profiling, all batched "
f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}")
parser.add_argument("--batch-size",
type=int,
default=BATCH_SIZE_DEFAULT,
help=f"Number of requests to run as a single batch, "
f"default={BATCH_SIZE_DEFAULT}")
parser.add_argument(
"--output-len",
type=int,
default=OUTPUT_LEN_DEFAULT,
help="Number of llm steps to run (includes prefill and decode) "
"- default={OUTPUT_LEN_DEFAULT}")
EngineArgs.add_cli_args(parser)
args = parser.parse_args()
context = ProfileContext(
engine_args=EngineArgs.from_cli_args(args),
**{
k: v
for k, v in vars(args).items()
if k in inspect.signature(ProfileContext).parameters
})
run_profile(context, csv_output=args.csv, json_output=args.json)

View File

@ -0,0 +1,77 @@
import argparse
import json
from typing import Dict
from vllm.profiler.layerwise_profile import ModelStatsEntry, SummaryStatsEntry
from vllm.profiler.utils import TablePrinter, indent_string
def flatten_entries(entry_cls, profile_dict: Dict):
entries_and_depth = []
def get_entries(node, curr_depth=0):
entries_and_depth.append((entry_cls(**node["entry"]), curr_depth))
for child in node["children"]:
get_entries(
child,
curr_depth=curr_depth + 1,
)
for root in profile_dict:
get_entries(root)
return entries_and_depth
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--json-trace",
type=str,
required=True,
help="json trace file output by "
"examples/offline_profile.py")
parser.add_argument("--phase",
type=str,
choices=["prefill", "decode_1"],
required=True,
help="The phase to print the table for.")
parser.add_argument("--table",
type=str,
choices=["summary", "model"],
default="summary",
help="Which table to print, the summary table or the "
"layerwise model table")
args = parser.parse_args()
with open(args.json_trace, "r") as f:
profile_data = json.load(f)
if args.table == "summary":
entries_and_depths = flatten_entries(
SummaryStatsEntry, profile_data[args.phase]["summary_stats"])
column_widths = dict(name=80,
cuda_time_us=12,
pct_cuda_time=12,
invocations=15)
elif args.table == "model":
entries_and_depths = flatten_entries(
ModelStatsEntry, profile_data[args.phase]["model_stats"])
column_widths = dict(name=60,
cpu_time_us=12,
cuda_time_us=12,
pct_cuda_time=12,
trace=60)
# indent entry names based on the depth
entries = []
for entry, depth in entries_and_depths:
entry.name = indent_string(
entry.name,
indent=depth,
indent_style=lambda indent: "|" + "-" * indent + " ")
entries.append(entry)
TablePrinter(type(entries[0]), column_widths).print_table(entries)

View File

@ -0,0 +1,522 @@
import argparse
import copy
import json
import math
import os
from pathlib import Path
from typing import Any, List, Optional, Tuple
import matplotlib.pyplot as plt
import pandas as pd
## JSON parsing utils ####
def largest_dist_from_leaf(node: dict, depth: int = 0):
if len(node["children"]) == 0:
return depth
return max([
largest_dist_from_leaf(child, depth=depth + 1)
for child in node["children"]
])
def get_entries_at_depth(depth: int,
entries_and_traces: List[Tuple[Any, Any]],
node: dict,
curr_depth: int = 0,
trace=()):
# assert that the query is at kernel or module level
assert depth == -1 or depth == -2
if curr_depth == 0 and largest_dist_from_leaf(node) <= (abs(depth) - 1):
# The tree is not tall enough!
entries_and_traces.append((node["entry"], trace))
return
if largest_dist_from_leaf(node) == (abs(depth) - 1):
entries_and_traces.append((node["entry"], trace))
trace = (node["entry"]["name"], ) + trace
for child in node["children"]:
get_entries_at_depth(depth,
entries_and_traces,
child,
curr_depth=curr_depth + 1,
trace=trace)
def fold_nodes(root: dict, nodes_to_fold: List[str]):
stack: List[dict] = [root]
while len(stack) != 0:
node = stack.pop()
if node['entry']['name'] in nodes_to_fold:
node["children"] = []
continue
for child in node["children"]:
stack.append(child)
return root
## Operation name cleanup utils ####
def trim_string_back(string: str, width: int) -> str:
if len(string) > width:
offset = len(string) - width + 3
string = string[:-offset]
if len(string) > 3:
string = string + "..."
return string
def shorten_plot_legend_strings(legend, max_char_len: int):
for t in legend.get_texts():
t.set_text(
trim_string_back(abbreviate_known_names(t.get_text()),
max_char_len))
def abbreviate_known_names(name: str) -> str:
abbreviations = {
"MergedColumnParallelLinear": "MCPLinear",
"QKVParallelLinear": "QKVPLinear",
"RowParallelLinear": "RPLinear",
"weight=": "w=",
"bfloat16": "bf16",
"float16": "f16",
}
for key, value in abbreviations.items():
name = name.replace(key, value)
return name
def attempt_to_make_names_unique(entries_and_traces):
names, non_unique_names = (set(), set())
def all_the_same(items) -> bool:
return all(i == items[0] for i in items)
for entry, _ in entries_and_traces:
if entry["name"] in names:
non_unique_names.add(entry["name"])
else:
names.add(entry["name"])
for name in non_unique_names:
entries_and_traces_with_name = [(entry, trace)
for entry, trace in entries_and_traces
if entry["name"] == name]
zipped_traces = list(
zip(*[trace for _, trace in entries_and_traces_with_name]))
first_trace_difference = next(
(i for i, trace_eles in enumerate(zipped_traces)
if not all_the_same(trace_eles)), None)
if first_trace_difference is None:
# can't create a unique name, leave them names as the
# are they will get aggregated by the pivot_table call
continue
for entry, trace in entries_and_traces_with_name:
entry["name"] = " <- ".join((entry["name"], ) +
trace[:first_trace_difference + 1])
## Operation grouping utils ####
'''
Group operations in the given dataframe by some high-level ops like,
- gemms
- attention
- rms_norm
etc.
'''
def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
def is_rms_norm(op_name: str):
if "rms_norm_kernel" in op_name:
return True
def is_attention_block(op_name: str):
if "flash_fwd" in op_name or \
"reshape_and_cache_flash_kernel" in op_name:
return True
def is_quant(op_name: str):
if "scaled_fp8_quant" in op_name or \
"scaled_int8_quant" in op_name:
return True
def is_gemm_op(op_name: str):
if is_quant(op_name):
return False
if "xmma_gemm" in op_name or \
"gemv2T_kernel" in op_name or \
"splitKreduce" in op_name or \
"void cutlass::Kernel" in op_name or \
"void cutlass::device_kernel" in op_name or \
"s16816gemm" in op_name:
return True
def is_elementwise_op(op_name: str):
return "elementwise_kernel" in op_name
def is_mem_op(op_name: str):
return "memcpy" in op_name.lower() or \
"memset" in op_name.lower()
def is_vocab_embedding_op(op_name: str):
return "vocabparallelembed" in op_name.lower()
# nccl ops
def is_nccl_op(op_name: str):
return "nccl" in op_name.lower()
def is_nccl_all_reduce(op_name: str):
return is_nccl_op(op_name) and \
("all_reduce" in op_name.lower() or \
"allreduce" in op_name.lower())
def is_nccl_gather(op_name: str):
return is_nccl_op(op_name) and \
"gather" in op_name.lower()
def is_nccl_broadcast(op_name: str):
return is_nccl_op(op_name) and \
"broadcast" in op_name.lower()
# Reduce ops types
def is_cross_device_reduce_1stage(op_name: str):
return "cross_device_reduce_1stage" in op_name
def is_cross_device_reduce_2stage(op_name: str):
return "cross_device_reduce_2stage" in op_name
def is_custom_ar_all_reduce_unreg(op_name: str):
return "_C_custom_ar::all_reduce_unreg" in op_name
def is_reduce_kernel(op_name: str):
return "reduce_kernel" in op_name
headers = list(trace_df)
ops = copy.deepcopy(headers)
attention_ops = list(filter(lambda x: is_attention_block(x), ops))
ops = list(filter(lambda x: x not in attention_ops, ops))
quant_ops = list(filter(lambda x: is_quant(x), ops))
ops = list(filter(lambda x: x not in quant_ops, ops))
gemm_ops = list(filter(lambda x: is_gemm_op(x), ops))
ops = list(filter(lambda x: x not in gemm_ops, ops))
rms_norm_ops = list(filter(lambda x: is_rms_norm(x), ops))
ops = list(filter(lambda x: x not in rms_norm_ops, ops))
vocab_embed_ops = list(filter(lambda x: is_vocab_embedding_op(x), ops))
ops = list(filter(lambda x: x not in vocab_embed_ops, ops))
mem_ops = list(filter(lambda x: is_mem_op(x), ops))
ops = list(filter(lambda x: x not in mem_ops, ops))
elementwise_ops = list(filter(lambda x: is_elementwise_op(x), ops))
ops = list(filter(lambda x: x not in elementwise_ops, ops))
nccl_all_reduce_ops = list(filter(lambda x: is_nccl_all_reduce(x), ops))
ops = list(filter(lambda x: x not in nccl_all_reduce_ops, ops))
nccl_gather_ops = list(filter(lambda x: is_nccl_gather(x), ops))
ops = list(filter(lambda x: x not in nccl_gather_ops, ops))
nccl_broadcast_ops = list(filter(lambda x: is_nccl_broadcast(x), ops))
ops = list(filter(lambda x: x not in nccl_broadcast_ops, ops))
nccl_other_ops = list(filter(lambda x: is_nccl_op(x), ops))
ops = list(filter(lambda x: x not in nccl_other_ops, ops))
cross_device_reduce_1stage_ops = list(
filter(lambda x: is_cross_device_reduce_1stage(x), ops))
ops = list(filter(lambda x: x not in cross_device_reduce_1stage_ops, ops))
cross_device_reduce_2stage_ops = list(
filter(lambda x: is_cross_device_reduce_2stage(x), ops))
ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops))
custom_ar_all_reduce_unreg_ops = list(
filter(lambda x: is_custom_ar_all_reduce_unreg(x), ops))
ops = list(filter(lambda x: x not in custom_ar_all_reduce_unreg_ops, ops))
reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops))
ops = list(filter(lambda x: x not in reduce_kernel_ops, ops))
if len(attention_ops):
trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1)
if len(quant_ops):
trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1)
if len(gemm_ops):
trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1)
if len(rms_norm_ops):
trace_df['rms_norm_ops'] = trace_df[rms_norm_ops].agg("sum", axis=1)
if len(vocab_embed_ops):
trace_df['vocab_embed_ops'] = trace_df[vocab_embed_ops].agg("sum",
axis=1)
if len(mem_ops):
trace_df['mem_ops'] = trace_df[mem_ops].agg("sum", axis=1)
if len(elementwise_ops):
trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum",
axis=1)
if len(nccl_all_reduce_ops):
trace_df['nccl_all_reduce_ops'] = trace_df[nccl_all_reduce_ops].agg(
"sum", axis=1)
if len(nccl_gather_ops):
trace_df['nccl_gather_ops'] = trace_df[nccl_gather_ops].agg("sum",
axis=1)
if len(nccl_broadcast_ops):
trace_df['nccl_broadcast_ops'] = trace_df[nccl_broadcast_ops].agg(
"sum", axis=1)
if len(nccl_other_ops):
trace_df['nccl_other_ops'] = trace_df[nccl_other_ops].agg("sum",
axis=1)
if len(cross_device_reduce_1stage_ops):
trace_df['cross_device_reduce_1stage_ops'] = trace_df[
cross_device_reduce_1stage_ops].agg("sum", axis=1)
if len(cross_device_reduce_2stage_ops):
trace_df['cross_device_reduce_2stage_ops'] = trace_df[
cross_device_reduce_2stage_ops].agg("sum", axis=1)
if len(custom_ar_all_reduce_unreg_ops):
trace_df['custom_ar_all_reduce_unreg_ops'] = trace_df[
custom_ar_all_reduce_unreg_ops].agg("sum", axis=1)
if len(reduce_kernel_ops):
trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum",
axis=1)
trace_df.drop(
attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops +
mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops +
nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops +
cross_device_reduce_2stage_ops + custom_ar_all_reduce_unreg_ops +
reduce_kernel_ops,
axis=1,
inplace=True)
return trace_df
## Data plotting utils ####
def plot_trace_df(traces_df: pd.DataFrame,
plot_metric: str,
plot_title: str,
output: Optional[Path] = None):
phases = traces_df['phase'].unique()
traces_df = traces_df.pivot_table(index="phase",
columns="name",
values=plot_metric,
aggfunc="sum")
traces_df = group_trace_by_operations(traces_df)
# Make the figure
fig, ax = plt.subplots(1, figsize=(5, 8), sharex=True)
# Draw the stacked bars
ops = list(traces_df)
bottom = [0] * len(phases)
for op in ops:
values = [traces_df[op][phase] for phase in phases]
values = list(map(lambda x: 0.0 if math.isnan(x) else x, values))
ax.bar(phases, values, label=op, bottom=bottom)
bottom = [bottom[j] + values[j] for j in range(len(phases))]
# Write the values as text on the bars
for bar in ax.patches:
if bar.get_height() != 0:
ax.text(bar.get_x() + bar.get_width() / 2,
bar.get_height() / 2 + bar.get_y(),
f"{round(bar.get_height(), 2)}",
ha='center',
color='w',
weight='bold',
size=5)
# Setup legend
handles, labels = plt.gca().get_legend_handles_labels()
legend = fig.legend(handles,
labels,
loc='center left',
bbox_to_anchor=(1, 1))
shorten_plot_legend_strings(legend, 50)
# Setup labels and title
plt.setp(ax.get_xticklabels(), rotation=90)
ax.set_ylabel(plot_metric)
plt.suptitle(plot_title)
plt.savefig(output, bbox_inches='tight')
print("Created: ", output)
def main(
json_trace: Path,
output_directory: Path,
depth: int, # Fetch/Plot operations at this depth of the Json tree
plot_metric: str,
make_names_unique: bool,
top_k: int,
json_nodes_to_fold: List[str]):
def prepare_data(profile_json: dict, step_keys: List[str]) -> pd.DataFrame:
def get_entries_and_traces(key: str):
entries_and_traces: List[Tuple[Any, Any]] = []
for root in profile_json[key]["summary_stats"]:
# Fold nodes in the traces as per user request. i.e. simply
# make the requested nodes leaf-nodes.
root = fold_nodes(root, json_nodes_to_fold)
get_entries_at_depth(depth, entries_and_traces, root)
return entries_and_traces
def keep_only_top_entries(df: pd.DataFrame,
metric: str,
top_k: int = 9) -> pd.DataFrame:
df.loc[df.nsmallest(len(df) - top_k + 1, metric).index,
["name"]] = "others"
return df
# Get data for each key
traces = list(map(lambda x: get_entries_and_traces(x), step_keys))
# Attempt some cleanup
if make_names_unique:
for trace in traces:
attempt_to_make_names_unique(trace)
# To pandas dataframe
trace_dfs = list(
map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0),
traces))
# Respect top_k
if top_k:
trace_dfs = list(
map(
lambda trace_df: keep_only_top_entries(
trace_df, "cuda_time_us", top_k), trace_dfs))
# Fill in information about the step-keys
for trace_df, step_key in zip(trace_dfs, step_keys):
trace_df['phase'] = step_key
# Combine all data frames so they can be put in a single plot
traces_df = pd.concat(trace_dfs)
# Add a derived metric `cuda_time_ms`
traces_df["cuda_time_ms"] = traces_df["cuda_time_us"] / 1000
traces_df = traces_df.fillna(0)
return traces_df
def make_plot_title_suffix(profile_json: dict) -> str:
context = profile_json["context"]
sparsity = context.get('sparsity', None)
return (f"{context['model']}\n"
f"Batch={context['batch_size']}, "
f"PromptLen={context['prompt_len']}, "
f"OutputLen={context['output_len']},"
f"NumGpus={context['tensor_parallel_size']}"
f"{', Sparsity ' + sparsity if sparsity else ''}")
profile_json = None
with open(json_trace, "r") as f:
profile_json = json.load(f)
assert profile_json is not None
# Get all `llm.generate.step()` profile
step_traces = list(profile_json.keys())
assert (step_traces[0] == 'context')
step_traces = step_traces[1:] # have only prefill and decodes
prefills = list(filter(lambda x: "prefill" in x, step_traces))
all_decodes = list(filter(lambda x: "decode" in x, step_traces))
assert len(prefills) + len(all_decodes) == len(step_traces)
assert len(prefills) == 1
decodes = all_decodes[::args.step_plot_interval]
if decodes[-1] != all_decodes[-1]:
# Always have the last decode
decodes.append(all_decodes[-1])
prefill_traces = prepare_data(profile_json, prefills)
decode_traces = prepare_data(profile_json, decodes)
plot_title_suffix = make_plot_title_suffix(profile_json)
plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix,
output_directory / Path("prefill.png"))
plot_trace_df(decode_traces, plot_metric, "decodes " + plot_title_suffix,
output_directory / Path("decode_steps.png"))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--json-trace",
type=str,
required=True,
help="json trace file output by examples/offline_profile.py")
parser.add_argument("--output-directory",
type=str,
required=False,
help="Directory to output plots")
parser.add_argument("--level",
type=str,
default="module",
choices=["module", "kernel"])
parser.add_argument("--top-k",
type=int,
default=12,
help="Only graph the top `top_k` entries by time.")
parser.add_argument("--fold-json-node",
nargs='+',
default=['Sampler', 'LogitsProcessor'],
help='Do not plot the children of these nodes. Let, \
the node represent the aggregate of all its \
children')
parser.add_argument("--plot-metric",
type=str,
default="cuda_time_ms",
help='Metric to plot. some options are cuda_time_ms, \
pct_cuda_time')
parser.add_argument(
"--step-plot-interval",
type=int,
default=4,
help="For every `step_plot_interval` steps, plot 1 step")
args = parser.parse_args()
# Prepare/Extract relevant args
make_names_unique = False
if args.level == "module":
depth = -2
make_names_unique = True
elif args.level == "kernel":
depth = -1
else:
raise Exception(f"Unexpected level value ({args.level})")
output_directory = args.output_directory if args.output_directory else Path(
args.json_trace).parent
if not os.path.exists(output_directory):
os.makedirs(output_directory)
main(Path(args.json_trace), output_directory, depth, args.plot_metric,
make_names_unique, args.top_k, args.fold_json_node)

View File

@ -0,0 +1,5 @@
from .layerwise_profile import layerwise_profile
__all__ = [
"layerwise_profile",
]

View File

@ -0,0 +1,354 @@
import copy
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, TypeAlias, Union
import pandas as pd
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
from torch._C._profiler import _EventType, _ExperimentalConfig, _ProfilerEvent
from torch.autograd.profiler import FunctionEvent
from torch.profiler import ProfilerActivity, profile
from vllm.profiler.utils import (TablePrinter, event_has_module,
event_is_torch_op, event_module_repr,
event_torch_op_stack_trace, indent_string)
@dataclass
class _ModuleTreeNode:
event: _ProfilerEvent
parent: Optional['_ModuleTreeNode'] = None
children: List['_ModuleTreeNode'] = field(default_factory=list)
trace: str = ""
@property
def is_leaf(self):
return (self.event.children is None or len(self.event.children) == 0)
@property
def is_torch_op(self):
return event_is_torch_op(self.event)
@property
def is_cuda(self):
return (self.event.tag == _EventType.Kineto
and self.event.typed[1].device_type == DeviceType.CUDA)
@dataclass
class SummaryStatsEntry:
name: str
cuda_time_us: float
pct_cuda_time: float
invocations: int
@dataclass
class ModelStatsEntry:
name: str
cpu_time_us: float
cuda_time_us: float
pct_cuda_time: float
trace: str
StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry]
@dataclass
class _StatsTreeNode:
entry: StatsEntry
children: List[StatsEntry]
parent: Optional[StatsEntry]
@dataclass
class LayerwiseProfileResults(profile):
_kineto_results: _ProfilerResult
_kineto_event_correlation_map: Dict[int,
List[_KinetoEvent]] = field(init=False)
_event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False)
_module_tree: List[_ModuleTreeNode] = field(init=False)
_model_stats_tree: List[_StatsTreeNode] = field(init=False)
_summary_stats_tree: List[_StatsTreeNode] = field(init=False)
def __post_init__(self):
self._build_correlation_map()
self._build_module_tree()
self._build_stats_trees()
def print_model_table(self, column_widths: Dict[str, int] = None):
_column_widths = dict(name=60,
cpu_time_us=12,
cuda_time_us=12,
pct_cuda_time=12,
trace=60)
if column_widths:
_column_widths.update(**column_widths)
filtered_model_table = [
(depth, row)
for depth, row in self._flatten_stats_tree(self._model_stats_tree)
if row.cuda_time_us > 0 or row.cpu_time_us > 0
]
TablePrinter(ModelStatsEntry, _column_widths).print_table(
self._indent_row_names_based_on_depth(
filtered_model_table,
indent_style=lambda indent: "|" + "-" * indent + " "))
def print_summary_table(self, column_widths: Dict[str, int] = None):
_column_widths = dict(name=80,
cuda_time_us=12,
pct_cuda_time=12,
invocations=15)
if column_widths:
_column_widths.update(**column_widths)
filtered_summary_table = [(depth, row)
for depth, row in self._flatten_stats_tree(
self._summary_stats_tree)
if row.cuda_time_us > 0]
TablePrinter(SummaryStatsEntry, _column_widths).print_table(
self._indent_row_names_based_on_depth(
filtered_summary_table,
indent_style=lambda indent: "|" + "-" * indent + " "))
def export_model_stats_table_csv(self, filename: str):
df = pd.DataFrame([
asdict(row)
for _, row in self._flatten_stats_tree(self._model_stats_tree)
])
df.to_csv(filename)
def export_summary_stats_table_csv(self, filename: str):
df = pd.DataFrame([
asdict(row)
for _, row in self._flatten_stats_tree(self._summary_stats_tree)
])
df.to_csv(filename)
def convert_stats_to_dict(self) -> str:
return {
"summary_stats":
self._convert_stats_tree_to_dict(self._summary_stats_tree),
"model_stats":
self._convert_stats_tree_to_dict(self._model_stats_tree)
}
@staticmethod
def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int,
StatsEntry]],
indent_style: Union[Callable[[int],
str],
str] = " "):
indented_rows = []
for depth, row in depths_rows:
if row.cuda_time_us == 0:
continue
indented_row = copy.deepcopy(row)
indented_row.name = indent_string(indented_row.name, depth,
indent_style)
indented_rows.append(indented_row)
return indented_rows
def _build_correlation_map(self):
self._kineto_event_correlation_map = defaultdict(list)
for event in self._kineto_results.events():
self._kineto_event_correlation_map[event.correlation_id()].append(
event)
def _build_module_tree(self):
self._module_tree = []
event_tree = self._kineto_results.experimental_event_tree()
def _df_traversal(event: _ProfilerEvent,
curr_node: Optional[_ModuleTreeNode] = None):
# For the tensor parallel case for now only look at task 1
if event.start_tid != 1:
return
if event_has_module(event):
node = _ModuleTreeNode(event=event, parent=curr_node)
if curr_node:
curr_node.children.append(node)
else:
self._module_tree.append(node)
curr_node = node
is_leaf = (event.children is None or len(event.children) == 0)
if is_leaf and curr_node:
node = _ModuleTreeNode(
event=event,
parent=curr_node,
trace=event_torch_op_stack_trace(
event, until=lambda x: event_has_module(x)))
curr_node.children.append(node)
curr_node = node
for child in event.children:
_df_traversal(child, curr_node)
for root in event_tree:
_df_traversal(root)
def _get_kineto_gpu_event(self, node: _ModuleTreeNode):
if node.event.tag != _EventType.Kineto:
return None
correlated_kineto_events = self._kineto_event_correlation_map.get(
node.event.correlation_id, [])
iterator = (x for x in correlated_kineto_events
if x.device_type() == DeviceType.CUDA
and x.name() == node.event.name)
return next(iterator, None)
def _cumulative_cuda_time(self, node: _ModuleTreeNode):
'Return cuda time in microseconds'
def _cumulative_cuda_time_recursive(node: _ModuleTreeNode):
if node.is_leaf and (gpu_kineto_event :=
self._get_kineto_gpu_event(node)):
return gpu_kineto_event.duration_ns() / 1000.0
else:
cumulative_cuda_time = 0
for child in node.children:
cumulative_cuda_time += _cumulative_cuda_time_recursive(
child)
return cumulative_cuda_time
return _cumulative_cuda_time_recursive(node)
def _total_cuda_time(self):
return sum(
[self._cumulative_cuda_time(root) for root in self._module_tree])
def _build_stats_trees(self):
summary_dict: Dict[str, self.StatsTreeNode] = {}
total_cuda_time = self._total_cuda_time()
def pct_cuda_time(cuda_time_us):
return (cuda_time_us / total_cuda_time) * 100
def build_summary_stats_tree_df(
node: _ModuleTreeNode,
parent: Optional[_StatsTreeNode] = None,
summary_trace: Tuple[str] = ()):
if event_has_module(node.event):
name = event_module_repr(node.event)
cuda_time_us = self._cumulative_cuda_time(node)
elif (gpu_kineto_event := self._get_kineto_gpu_event(node)):
name = gpu_kineto_event.name()
cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0
else:
return None
summary_trace = summary_trace + (name, )
if summary_trace in summary_dict:
entry = summary_dict[summary_trace].entry
entry.cuda_time_us += cuda_time_us
entry.invocations += 1
entry.pct_cuda_time = pct_cuda_time(entry.cuda_time_us)
else:
new_node = _StatsTreeNode(entry=SummaryStatsEntry(
name=name,
cuda_time_us=cuda_time_us,
pct_cuda_time=pct_cuda_time(cuda_time_us),
invocations=1),
children=[],
parent=parent)
if parent:
parent.children.append(new_node)
summary_dict[summary_trace] = new_node
for child in node.children:
build_summary_stats_tree_df(child, summary_dict[summary_trace],
summary_trace)
return summary_dict[summary_trace]
self._summary_stats_tree = []
for root in self._module_tree:
self._summary_stats_tree.append(build_summary_stats_tree_df(root))
def build_model_stats_tree_df(node: _ModuleTreeNode,
parent: Optional[_StatsTreeNode] = None):
if event_has_module(node.event, ):
name = event_module_repr(node.event)
cuda_time_us = self._cumulative_cuda_time(node)
cpu_time_us = node.event.duration_time_ns / 1000
trace = ""
elif (gpu_kineto_event := self._get_kineto_gpu_event(node)):
name = gpu_kineto_event.name()
cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0
cpu_time_us = 0
trace = node.trace
else:
return None
new_node = _StatsTreeNode(entry=ModelStatsEntry(
name=name,
cpu_time_us=cpu_time_us,
cuda_time_us=cuda_time_us,
pct_cuda_time=pct_cuda_time(cuda_time_us),
trace=trace),
parent=parent,
children=[])
if parent:
parent.children.append(new_node)
for child in node.children:
build_model_stats_tree_df(child, new_node)
return new_node
self._model_stats_tree = []
for root in self._module_tree:
self._model_stats_tree.append(build_model_stats_tree_df(root))
def _flatten_stats_tree(
self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]:
entries: List[Tuple[int, StatsEntry]] = []
def df_traversal(node: _StatsTreeNode, depth=0):
entries.append((depth, node.entry))
for child in node.children:
df_traversal(child, depth=depth + 1)
for root in tree:
df_traversal(root)
return entries
def _convert_stats_tree_to_dict(self,
tree: List[_StatsTreeNode]) -> List[Dict]:
root_dicts: List[Dict] = []
def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]):
curr_json_list.append({
"entry": asdict(node.entry),
"children": []
})
for child in node.children:
df_traversal(child, curr_json_list[-1]["children"])
for root in tree:
df_traversal(root, root_dicts)
return root_dicts
class layerwise_profile(profile):
def __init__(self):
super().__init__(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
with_stack=True,
with_modules=True,
experimental_config=_ExperimentalConfig(verbose=True))
def __enter__(self):
return super().__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
super().__exit__(exc_type, exc_val, exc_tb)
self.results = LayerwiseProfileResults(self.profiler.kineto_results)

145
vllm/profiler/utils.py Normal file
View File

@ -0,0 +1,145 @@
import dataclasses
from typing import Callable, Dict, List, Type, Union
from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata
#
# String / Print Manipulation
#
def trim_string_front(string, width):
if len(string) > width:
offset = len(string) - width + 3
string = string[offset:]
if len(string) > 3:
string = "..." + string[3:]
return string
def trim_string_back(string, width):
if len(string) > width:
offset = len(string) - width + 3
string = string[:-offset]
if len(string) > 3:
string = string + "..."
return string
class TablePrinter:
def __init__(self, row_cls: Type[dataclasses.dataclass],
column_widths: Dict[str, int]):
self.row_cls = row_cls
self.fieldnames = [x.name for x in dataclasses.fields(row_cls)]
self.column_widths = column_widths
assert set(self.column_widths.keys()) == set(self.fieldnames)
def print_table(self, rows: List[dataclasses.dataclass]):
self._print_header()
self._print_line()
for row in rows:
self._print_row(row)
def _print_header(self):
for i, f in enumerate(self.fieldnames):
last = (i == len(self.fieldnames) - 1)
col_width = self.column_widths[f]
print(trim_string_back(f, col_width).ljust(col_width),
end=" | " if not last else "\n")
def _print_row(self, row):
assert isinstance(row, self.row_cls)
for i, f in enumerate(self.fieldnames):
last = (i == len(self.fieldnames) - 1)
col_width = self.column_widths[f]
val = getattr(row, f)
val_str = ""
if isinstance(val, str):
val_str = trim_string_back(val, col_width).ljust(col_width)
elif type(val) in [float, int]:
val_str = f"{float(val):>.2f}".rjust(col_width)
else:
val_str = f"{val}".rjust(col_width)
print(val_str, end=" | " if not last else "\n")
def _print_line(self):
total_col_width = 0
for column_width in self.column_widths.values():
total_col_width += column_width
print("=" * (total_col_width + 3 * (len(self.column_widths) - 1)))
def indent_string(string: str,
indent: int,
indent_style: Union[Callable[[int], str], str] = " ") -> str:
if indent:
if isinstance(indent_style, str):
return indent_style * indent + string
else:
return indent_style(indent) + string
else:
return string
#
# _ProfilerEvent utils
#
def event_has_module(event: _ProfilerEvent) -> bool:
event_type, typed_event = event.typed
if event_type == _EventType.PyCall:
return typed_event.module is not None
return False
def event_is_torch_op(event: _ProfilerEvent) -> bool:
return event.tag == _EventType.TorchOp
def event_arg_repr(arg) -> str:
if arg is None or type(arg) in [float, int, bool, str]:
return f"{arg}"
elif isinstance(arg, list):
return f"[{', '.join([event_arg_repr(x) for x in arg])}]"
elif isinstance(arg, tuple):
return f"({', '.join([event_arg_repr(x) for x in arg])})"
else:
assert isinstance(arg,
_TensorMetadata), f"Unsupported type: {type(arg)}"
sizes_str = ', '.join([str(x) for x in arg.sizes])
return f"{str(arg.dtype).replace('torch.', '')}[{sizes_str}]"
def event_torch_op_repr(event: _ProfilerEvent) -> str:
assert event.tag == _EventType.TorchOp
args_str = ', '.join([event_arg_repr(x) for x in event.typed[1].inputs])
return f"{event.name}({args_str})".replace("aten::", "")
def event_module_repr(event: _ProfilerEvent) -> str:
assert event_has_module(event)
module = event.typed[1].module
if module.parameters and len(module.parameters) > 0:
args_str = ', '.join(
[f'{x[0]}={event_arg_repr(x[1])}' for x in module.parameters])
return f"{module.cls_name}({args_str})"
else:
return module.cls_name
def event_torch_op_stack_trace(curr_event: _ProfilerEvent,
until: Callable[[_ProfilerEvent], bool]) -> str:
trace = ""
curr_event = curr_event.parent
while curr_event and not until(curr_event):
if event_is_torch_op(curr_event):
if len(trace) > 0:
trace += " <- "
trace += event_torch_op_repr(curr_event)
curr_event = curr_event.parent
return trace

View File

@ -1742,10 +1742,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
return [output]
class CUDAGraphRunner:
# NOTE: this is nn.Module so the profiler can properly capture/group
# kernels calls made within the graph
class CUDAGraphRunner(nn.Module):
def __init__(self, model: nn.Module, backend_name: str,
attn_state: AttentionState, is_encoder_decoder_model: bool):
super().__init__()
self.model = model
self.backend_name = backend_name
self.attn_state = attn_state
@ -1892,9 +1895,6 @@ class CUDAGraphRunner:
return self.output_buffers
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def _get_graph_batch_size(batch_size: int) -> int:
"""Returns the padded batch size given actual batch size.