[misc] Layerwise profile updates (#10242)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Author: Varun Sundar Rabindranath <varun@neuralmagic.com>
Date: 2024-12-16 13:14:57 -05:00 (committed by GitHub)
parent 2ca830dbaa
commit efbce85f4d
5 changed files with 314 additions and 47 deletions

(changed file)

@@ -201,7 +201,7 @@ steps:
     - python3 offline_inference_classification.py
     - python3 offline_inference_embedding.py
     - python3 offline_inference_scoring.py
-    - python3 offline_profile.py --model facebook/opt-125m
+    - python3 offline_profile.py --model facebook/opt-125m run_num_steps --num-steps 2
 - label: Prefix Caching Test # 9min
   mirror_hardwares: [amd]

(changed file)

@@ -4,9 +4,10 @@ import os
 import sys
 from argparse import RawTextHelpFormatter
 from dataclasses import asdict, dataclass
-from typing import Optional
+from typing import Any, Dict, Generator, List, Optional, TypeAlias

 import torch
+import tqdm

 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
@@ -15,16 +16,21 @@ from vllm.utils import FlexibleArgumentParser

 BATCH_SIZE_DEFAULT = 1
 PROMPT_LEN_DEFAULT = 256
-OUTPUT_LEN_DEFAULT = 2


 @dataclass
 class ProfileContext:
     engine_args: EngineArgs
     prompt_len: int
-    output_len: int
     batch_size: int
-    save_chrome_traces_folder: Optional[str]
+
+    # The profiler can run in 2 modes,
+    # 1. Run profiler for user specified num_steps
+    num_steps: Optional[int] = None
+    # 2. Run profiler until all requests complete
+    complete_num_requests_per_step: Optional[int] = None
+
+    save_chrome_traces_folder: Optional[str] = None


 def get_dtype(dtype: str):
@@ -34,23 +40,155 @@ def get_dtype(dtype: str):
     return dtype


+OutputLen_NumReqs_Map: TypeAlias = Dict[int, int]
+def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \
+        -> OutputLen_NumReqs_Map:
+    """
+    Given the number of requests, batch_size, and the number of requests
+    that each engine-step should process, step_requests, determine the
+    output lengths of the requests such that step_requests is honoured.
+
+    Example:
+    if batch_size = 128 and step_requests = [128, 128, 96, 64, 32, 1]
+    then return,
+    {2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning,
+    32 requests should have output length 2,
+    32 requests should have output length 3,
+    32 requests should have output length 4,
+    31 requests should have output length 5,
+    1 request should have output length 6.
+
+    Args:
+        batch_size (int): Number of requests submitted for profile. This is
+            args.batch_size.
+        step_requests (List[int]): step_requests[i] is the number of requests
+            that the ith engine step should process.
+
+    Returns:
+        OutputLen_NumReqs_Map: A dictionary with output-length as keys and the
+            number of requests required to have that output-length as values.
+    """
+    ol_nr: OutputLen_NumReqs_Map = {}
+
+    # Number of requests that have been assigned an output-length
+    num_reqs_assigned: int = 0
+    num_steps: int = len(step_requests)
+
+    # sanity check. The first step (prefill-step) must process all requests.
+    assert step_requests[0] == batch_size
+
+    # Begin assignments from the last step.
+    output_length: int = num_steps
+    for num_requests_at_step in reversed(step_requests):
+        if num_reqs_assigned == batch_size:
+            break
+
+        assert num_reqs_assigned < batch_size
+
+        # Remove the number of requests that have been determined
+        # to participate in this step and beyond.
+        num_reqs_unassigned_at_step = num_requests_at_step - num_reqs_assigned
+        assert num_reqs_unassigned_at_step >= 0
+
+        if num_reqs_unassigned_at_step > 0:
+            ol_nr[output_length] = num_reqs_unassigned_at_step
+            num_reqs_assigned += num_reqs_unassigned_at_step
+
+        output_length -= 1
+
+    # sanity checks.
+    assert sum(ol_nr.values()) == batch_size, \
+        ("Number of requests in output-length assignment does not match "
+         f"batch-size.\n batch size {batch_size} - "
+         f"step requests {step_requests} - assignments {ol_nr}")
+
+    # Check that the output-length is in [1, num-steps]. Output length must be
+    # at least 1, as all requests must participate in the prefill-step.
+    assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), \
+        ("Output lengths of requests should be in range "
+         f"[1, num-engine-steps].\n batch size {batch_size} - "
+         f"step requests {step_requests} - assignments {ol_nr}")
+
+    return ol_nr
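
As a quick sanity check, the docstring example above can be reproduced directly (a small sketch; it assumes the function as defined in this diff is importable):

```python
# Reproduce the docstring example: 6 engine steps, batch of 128 requests.
step_requests = [128, 128, 96, 64, 32, 1]
ol_nr = compute_request_output_lengths(batch_size=128,
                                       step_requests=step_requests)
print(ol_nr)  # {6: 1, 5: 31, 4: 32, 3: 32, 2: 32}
assert sum(ol_nr.values()) == 128
```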
+def determine_requests_per_step(context: ProfileContext) -> List[int]:
+    """
+    Determine the number of requests each engine step should process.
+    If context.num_steps is set, then all engine steps process the
+    same number of requests and the output list is of length
+    context.num_steps.
+
+    If context.complete_num_requests_per_step is set, then each decode step
+    processes fewer and fewer requests until there are no requests to process.
+    In this case, the output list is as big as the number of steps
+    required to process all requests.
+
+    Args:
+        context: ProfileContext object.
+
+    Returns:
+        List[int]: Number of requests to process for all engine-steps.
+            output[i] contains the number of requests that the ith step
+            should process.
+    """
+    if context.num_steps:
+        # All requests must run until num_engine_steps. This implies
+        # that their output lengths must be equal to num_engine_steps.
+        return [context.batch_size] * context.num_steps
+
+    assert context.complete_num_requests_per_step and \
+        context.complete_num_requests_per_step > 0, \
+        (f"Expected a positive complete_num_requests_per_step argument. "
+         f"Instead got {context.complete_num_requests_per_step}")
+
+    # We start dropping after the first decode step.
+    step_requests = [
+        context.batch_size,  # prefill
+        context.batch_size,  # decode
+    ]
+
+    num_running_requests = context.batch_size
+    num_running_requests -= context.complete_num_requests_per_step
+    while num_running_requests > 0:
+        step_requests.append(num_running_requests)
+        num_running_requests -= context.complete_num_requests_per_step
+
+    if step_requests[-1] != 1:
+        # Have 1 request running at the last step. This is often useful.
+        step_requests.append(1)
+
+    return step_requests
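
Together with the helper above, the run_to_completion schedule described later in the CLI help text falls out directly (a sketch; the EngineArgs values are only illustrative):

```python
# Sketch: schedule for run_to_completion with batch_size=128 and
# complete_num_requests_per_step=32.
ctx = ProfileContext(engine_args=EngineArgs(model="facebook/opt-125m"),
                     prompt_len=256,
                     batch_size=128,
                     complete_num_requests_per_step=32)
print(determine_requests_per_step(ctx))  # [128, 128, 96, 64, 32, 1]
```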
 def run_profile(context: ProfileContext, csv_output: Optional[str],
                 json_output: Optional[str]):
     print("Run profile with:")
     for key, value in asdict(context).items():
         print(f"  {key} = {value}")

+    requests_per_step: List[int] = determine_requests_per_step(context)
+    ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
+        context.batch_size, requests_per_step)
+
+    num_steps_to_profile: int = len(requests_per_step)
+    max_output_len: int = max(ol_nr.keys())
+    assert max_output_len >= 1
+
     # Create sampling params
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=args.output_len,
-                                     ignore_eos=True)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        # max_tokens is set on a per-request basis.
+        max_tokens=None,
+        ignore_eos=True)

     # Create LLM
     llm = LLM(**asdict(context.engine_args))
     batch_size = context.batch_size
     prompt_len = context.prompt_len
-    output_len = context.output_len

     scheduler_config = llm.llm_engine.scheduler_config
     max_model_len = llm.llm_engine.model_config.max_model_len
@@ -65,7 +203,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
               f"choose a smaller batch size or prompt length, or increase "
               f"--max-num-batched-tokens")
         sys.exit(-1)
-    if batch_size >= max_num_seqs:
+    if batch_size > max_num_seqs:
         print(
             f"ERROR: chosen batch_size ({batch_size}) is larger than "
             f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
@@ -73,16 +211,26 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
         sys.exit(-1)
     print("llm.llm_engine.model_config.max_model_len: ",
           llm.llm_engine.model_config.max_model_len)
-    if prompt_len + output_len > llm.llm_engine.model_config.max_model_len:
-        print(
-            f"ERROR: chosen prompt_len + output_len ({prompt_len} + "
-            f"{output_len} = {prompt_len + output_len}) is larger than the "
-            f"model's max_model_len ({max_model_len}), please choose a smaller "
-            f"prompt_len or output_len, or increase --max-model-len")
+    if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len:
+        print(f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + "
+              f"{max_output_len} = {prompt_len + max_output_len}) is larger "
+              f"than the model's max_model_len ({max_model_len}), please "
+              f"choose a smaller prompt_len or max_output_len, or increase "
+              f"--max-model-len")
         sys.exit(-1)
     def add_requests():

+        def get_output_len_generator() -> Generator[int, Any, Any]:
+            for output_len, num_reqs in ol_nr.items():
+                for _ in range(num_reqs):
+                    yield output_len
+
+        output_len_generator = get_output_len_generator()
         for i in range(batch_size):
+            sampling_params.max_tokens = next(output_len_generator)
+            assert isinstance(sampling_params.max_tokens, int)
+
             prompt_token_ids = torch.randint(
                 llm.llm_engine.model_config.get_vocab_size(),
                 size=(prompt_len, )).tolist()
@@ -110,8 +258,11 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
         llm.llm_engine.step()  # First step is prefill

     decode_profs = []
-    for x in range(args.output_len - 1):
-        with layerwise_profile() as decode_prof:
+    for _ in tqdm.tqdm(range(num_steps_to_profile - 1)):
+        num_running_seqs = llm.llm_engine.scheduler[
+            0].get_num_unfinished_seq_groups()
+        with layerwise_profile(
+                num_running_seqs=num_running_seqs) as decode_prof:
             llm.llm_engine.step()
         decode_profs.append(decode_prof)
@@ -154,7 +305,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
         decode_results_list[0].print_summary_table()

     if csv_output:
-        csv_filename_base = csv_output.rstrip(".csv")
+        csv_filename_base = csv_output[:-4] \
+            if csv_output.endswith('.csv') else csv_output
         prefill_results.export_model_stats_table_csv(
             csv_filename_base + "_prefill_model_table.csv")
         prefill_results.export_summary_stats_table_csv(
@@ -187,10 +339,10 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
         for idx, dr in enumerate(decode_results_list):
             json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()

-        for idx, dr in enumerate(decode_results_list[1:]):
-            json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()
-
-        with open(json_output.rstrip(".json") + ".json", "w+") as f:
+        # Add .json to json_output filename if it doesn't exist already.
+        json_output_file = json_output if json_output.endswith(
+            '.json') else json_output + '.json'
+        with open(json_output_file, "w+") as f:
             json.dump(json_dict, f, indent=2)

     pass
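
For orientation, the JSON written here ends up shaped roughly as below. The key names are taken from elsewhere in this diff; the concrete values are illustrative, not output from a real run:

```python
# Approximate layout of the profiler's JSON output (illustrative values).
profile_json_shape = {
    "context": {
        "batch_size": 128,
        "prompt_len": 512,
        "num_steps": None,
        "complete_num_requests_per_step": 32,
        "engine_args": {"model": "...", "tensor_parallel_size": 1},
    },
    "prefill": {"metadata": {"num_running_seqs": None},
                "summary_stats": "...", "model_stats": "..."},
    "decode_1": {"metadata": {"num_running_seqs": 128},
                 "summary_stats": "...", "model_stats": "..."},
    # decode_2 ... decode_N follow the same layout.
}
```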
@@ -214,7 +366,7 @@ Profile a model
     python examples/offline_profile.py \\
         --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\
         --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\
-        --enforce-eager
+        --enforce-eager run_num_steps -n 2
     ```
then you can use various tools to analyze the json output then you can use various tools to analyze the json output
@@ -261,17 +413,41 @@ Profile a model
         default=BATCH_SIZE_DEFAULT,
         help=f"Number of requests to run as a single batch, "
         f"default={BATCH_SIZE_DEFAULT}")
-    parser.add_argument(
-        "--output-len",
-        type=int,
-        default=OUTPUT_LEN_DEFAULT,
-        help="Number of llm steps to run (includes prefill and decode) "
-        "- default={OUTPUT_LEN_DEFAULT}")
+
+    subparsers = parser.add_subparsers(dest="cmd")
+
+    run_num_steps_parser = subparsers.add_parser(
+        "run_num_steps",
+        help="This variation profiles n engine.step() invocations.")
+    run_num_steps_parser.add_argument(
+        '-n',
+        '--num-steps',
+        type=int,
+        help="Number of engine steps to profile.\n"
+        "Setting it to 1 profiles only the prefill step.\n"
+        "Setting it to 2 profiles the prefill and first decode step.\n"
+        "Setting it to 3 profiles the prefill, 1st and 2nd decode steps, "
+        "and so on ...")
+
+    run_to_completion_parser = subparsers.add_parser(
+        "run_to_completion",
+        help="This variation profiles all the engine.step() invocations "
+        "until the engine exhausts all submitted requests.")
+    run_to_completion_parser.add_argument(
+        '-n',
+        '--complete-num-requests-per-step',
+        type=int,
+        help=
+        "Complete complete_num_requests_per_step requests every decode step. "
+        "For example, with batch_size 128 and complete_num_requests_per_step "
+        "32, the profiler is run for 6 engine steps, with the steps "
+        "processing 128, 128, 96, 64, 32, 1 requests respectively.\n"
+        "Note that we tack on a one-request step at the end, as it is often "
+        "useful.")
     EngineArgs.add_cli_args(parser)

     args = parser.parse_args()

     context = ProfileContext(
         engine_args=EngineArgs.from_cli_args(args),
         **{

(changed file)

@@ -34,9 +34,10 @@ if __name__ == "__main__":
                         "examples/offline_profile.py")
     parser.add_argument("--phase",
                         type=str,
-                        choices=["prefill", "decode_1"],
                         required=True,
-                        help="The phase to print the table for.")
+                        help="The phase to print the table for. This is "
+                        "either prefill or decode_n, where n is the decode "
+                        "step number")
     parser.add_argument("--table",
                         type=str,
                         choices=["summary", "model"],
@@ -49,6 +50,10 @@ if __name__ == "__main__":
     with open(args.json_trace) as f:
         profile_data = json.load(f)

+    assert args.phase in profile_data, \
+        (f"Cannot find phase {args.phase} in profile data. Choose one among "
+         f'{[x for x in profile_data.keys() if "prefill" in x or "decode" in x]}')  #noqa
+
     if args.table == "summary":
         entries_and_depths = flatten_entries(
             SummaryStatsEntry, profile_data[args.phase]["summary_stats"])

(changed file)

@@ -151,16 +151,31 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
                 "scaled_int8_quant" in op_name:
             return True

+    # LoRA ops
+    def is_sgmv_shrink(op_name: str):
+        return "sgmv_shrink" in op_name
+
+    def is_sgmv_expand(op_name: str):
+        return "sgmv_expand" in op_name
+
+    def is_bgmv_shrink(op_name: str):
+        return "bgmv_shrink" in op_name
+
+    def is_bgmv_expand(op_name: str):
+        return "bgmv_expand" in op_name
+
+    def is_cutlass_gemm_op(op_name: str):
+        return "void cutlass::Kernel" in op_name or \
+            "void cutlass::device_kernel" in op_name
+
     def is_gemm_op(op_name: str):
         if is_quant(op_name):
             return False
-        if "xmma_gemm" in op_name or \
-           "gemv2T_kernel" in op_name or \
-           "splitKreduce" in op_name or \
-           "void cutlass::Kernel" in op_name or \
-           "void cutlass::device_kernel" in op_name or \
-           "s16816gemm" in op_name:
-            return True
+        return is_cutlass_gemm_op(op_name) or \
+            "xmma_gemm" in op_name or \
+            "gemv2T_kernel" in op_name or \
+            "splitKreduce" in op_name or \
+            "s16816gemm" in op_name

     def is_elementwise_op(op_name: str):
         return "elementwise_kernel" in op_name
@@ -211,6 +226,18 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
     quant_ops = list(filter(lambda x: is_quant(x), ops))
     ops = list(filter(lambda x: x not in quant_ops, ops))

+    sgmv_shrink_ops = list(filter(lambda x: is_sgmv_shrink(x), ops))
+    ops = list(filter(lambda x: x not in sgmv_shrink_ops, ops))
+    sgmv_expand_ops = list(filter(lambda x: is_sgmv_expand(x), ops))
+    ops = list(filter(lambda x: x not in sgmv_expand_ops, ops))
+    bgmv_shrink_ops = list(filter(lambda x: is_bgmv_shrink(x), ops))
+    ops = list(filter(lambda x: x not in bgmv_shrink_ops, ops))
+    bgmv_expand_ops = list(filter(lambda x: is_bgmv_expand(x), ops))
+    ops = list(filter(lambda x: x not in bgmv_expand_ops, ops))
+
+    cutlass_gemm_ops = list(filter(lambda x: is_cutlass_gemm_op(x), ops))
+    ops = list(filter(lambda x: x not in cutlass_gemm_ops, ops))
+
     gemm_ops = list(filter(lambda x: is_gemm_op(x), ops))
     ops = list(filter(lambda x: x not in gemm_ops, ops))
@@ -257,6 +284,24 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
         trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1)
     if len(quant_ops):
         trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1)
+
+    if len(sgmv_shrink_ops):
+        trace_df['sgmv_shrink_ops'] = trace_df[sgmv_shrink_ops].agg("sum",
+                                                                    axis=1)
+    if len(sgmv_expand_ops):
+        trace_df['sgmv_expand_ops'] = trace_df[sgmv_expand_ops].agg("sum",
+                                                                    axis=1)
+    if len(bgmv_shrink_ops):
+        trace_df['bgmv_shrink_ops'] = trace_df[bgmv_shrink_ops].agg("sum",
+                                                                    axis=1)
+    if len(bgmv_expand_ops):
+        trace_df['bgmv_expand_ops'] = trace_df[bgmv_expand_ops].agg("sum",
+                                                                    axis=1)
+
+    if len(cutlass_gemm_ops):
+        trace_df['cutlass_gemm_ops'] = trace_df[cutlass_gemm_ops].agg("sum",
+                                                                      axis=1)
+
     if len(gemm_ops):
         trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1)
     if len(rms_norm_ops):
@@ -296,7 +341,9 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
         trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum",
                                                                         axis=1)

-    trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops +
+    trace_df.drop(attention_ops + quant_ops + sgmv_shrink_ops +
+                  sgmv_expand_ops + bgmv_shrink_ops + bgmv_expand_ops +
+                  cutlass_gemm_ops + gemm_ops + rms_norm_ops +
                   vocab_embed_ops + mem_ops + elementwise_ops +
                   nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops +
                   nccl_other_ops + cross_device_reduce_1stage_ops +
@@ -315,7 +362,14 @@ def plot_trace_df(traces_df: pd.DataFrame,
                   plot_title: str,
                   output: Optional[Path] = None):

+    def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str:
+        phase_df = traces_df.query(f'phase == "{phase}"')
+        descs = phase_df['phase_desc'].to_list()
+        assert all([desc == descs[0] for desc in descs])
+        return descs[0]
+
     phases = traces_df['phase'].unique()
+    phase_descs = [get_phase_description(traces_df, p) for p in phases]
     traces_df = traces_df.pivot_table(index="phase",
                                       columns="name",
                                       values=plot_metric,
@@ -324,7 +378,8 @@ def plot_trace_df(traces_df: pd.DataFrame,
     traces_df = group_trace_by_operations(traces_df)

     # Make the figure
-    fig, ax = plt.subplots(1, figsize=(5, 8), sharex=True)
+    fig_size_x = max(5, len(phases))
+    fig, ax = plt.subplots(1, figsize=(fig_size_x, 8), sharex=True)

     # Draw the stacked bars
     ops = list(traces_df)
@@ -332,7 +387,7 @@ def plot_trace_df(traces_df: pd.DataFrame,
     for op in ops:
         values = [traces_df[op][phase] for phase in phases]
         values = list(map(lambda x: 0.0 if math.isnan(x) else x, values))
-        ax.bar(phases, values, label=op, bottom=bottom)
+        ax.bar(phase_descs, values, label=op, bottom=bottom)
         bottom = [bottom[j] + values[j] for j in range(len(phases))]

     # Write the values as text on the bars
@@ -390,6 +445,14 @@ def main(
                  ["name"]] = "others"
         return df

+    def get_phase_description(key: str) -> str:
+        num_running_seqs = profile_json[key]['metadata'][
+            'num_running_seqs']
+        if num_running_seqs is not None:
+            return f"{key}-seqs-{num_running_seqs}"
+        else:
+            return key
+
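
The effect on the plot's x-axis labels, shown with made-up metadata (a standalone sketch of the same logic, not the code in main itself):

```python
# Hypothetical profile contents, just to show the resulting phase labels.
profile_json = {
    "prefill": {"metadata": {"num_running_seqs": None}},
    "decode_1": {"metadata": {"num_running_seqs": 128}},
}

def get_phase_description(key: str) -> str:
    num_running_seqs = profile_json[key]['metadata']['num_running_seqs']
    return key if num_running_seqs is None else f"{key}-seqs-{num_running_seqs}"

print([get_phase_description(k) for k in profile_json])
# ['prefill', 'decode_1-seqs-128']
```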
     # Get data for each key
     traces = list(map(lambda x: get_entries_and_traces(x), step_keys))

@@ -413,6 +476,7 @@ def main(
     # Fill in information about the step-keys
     for trace_df, step_key in zip(trace_dfs, step_keys):
         trace_df['phase'] = step_key
+        trace_df['phase_desc'] = get_phase_description(step_key)

     # Combine all data frames so they can be put in a single plot
     traces_df = pd.concat(trace_dfs)
@@ -426,12 +490,16 @@ def main(
     def make_plot_title_suffix(profile_json: dict) -> str:
         context = profile_json["context"]
         sparsity = context.get('sparsity', None)
-        return (f"{context['model']}\n"
+        run_type = \
+            f'Run {context["num_steps"]} steps' if context['num_steps'] else \
+            (f'Complete {context["complete_num_requests_per_step"]} per '
+             f'step; Run till completion')
+        return (f"{context['engine_args']['model']}\n"
                 f"Batch={context['batch_size']}, "
                 f"PromptLen={context['prompt_len']}, "
-                f"OutputLen={context['output_len']},"
-                f"NumGpus={context['tensor_parallel_size']}"
-                f"{', Sparsity ' + sparsity if sparsity else ''}")
+                f"NumGpus={context['engine_args']['tensor_parallel_size']}"
+                f"{', Sparsity ' + sparsity if sparsity else ''}\n"
+                f"Run Type: {run_type}")

     profile_json = None
     with open(json_trace) as f:

(changed file)

@@ -72,6 +72,9 @@ class LayerwiseProfileResults(profile):
     _model_stats_tree: List[_StatsTreeNode] = field(init=False)
     _summary_stats_tree: List[_StatsTreeNode] = field(init=False)

+    # profile metadata
+    num_running_seqs: Optional[int] = None
+
     def __post_init__(self):
         self._build_correlation_map()
         self._build_module_tree()
@@ -127,6 +130,9 @@ class LayerwiseProfileResults(profile):
     def convert_stats_to_dict(self) -> str:
         return {
+            "metadata": {
+                "num_running_seqs": self.num_running_seqs
+            },
             "summary_stats":
             self._convert_stats_tree_to_dict(self._summary_stats_tree),
             "model_stats":
@@ -338,7 +344,15 @@ class LayerwiseProfileResults(profile):
 class layerwise_profile(profile):

-    def __init__(self):
+    def __init__(self, num_running_seqs: Optional[int] = None):
+        """
+        layerwise profile constructor.
+
+        Args:
+            num_running_seqs (Optional[int], optional): When given,
+                num_running_seqs is passed to LayerwiseProfileResults for the
+                metadata update. Defaults to None.
+        """
         super().__init__(
             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=True,
@@ -346,9 +360,13 @@ class layerwise_profile(profile):
             with_modules=True,
             experimental_config=_ExperimentalConfig(verbose=True))

+        self.num_running_seqs = num_running_seqs
+
     def __enter__(self):
         return super().__enter__()

     def __exit__(self, exc_type, exc_val, exc_tb):
         super().__exit__(exc_type, exc_val, exc_tb)
-        self.results = LayerwiseProfileResults(self.profiler.kineto_results)
+        self.results = LayerwiseProfileResults(
+            self.profiler.kineto_results,
+            num_running_seqs=self.num_running_seqs)
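
A minimal sketch of how this plumbs through end to end, using only calls that appear elsewhere in this diff (llm is assumed to be an already-constructed LLM, as in examples/offline_profile.py):

```python
# Sketch, not part of the diff: capture one decode step and read back the
# scheduler occupancy that was recorded as metadata.
num_running_seqs = llm.llm_engine.scheduler[0].get_num_unfinished_seq_groups()
with layerwise_profile(num_running_seqs=num_running_seqs) as decode_prof:
    llm.llm_engine.step()

stats = decode_prof.results.convert_stats_to_dict()
assert stats["metadata"]["num_running_seqs"] == num_running_seqs
```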