[misc] Layerwise profile updates (#10242)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
parent 2ca830dbaa
commit efbce85f4d
@@ -201,7 +201,7 @@ steps:
     - python3 offline_inference_classification.py
     - python3 offline_inference_embedding.py
     - python3 offline_inference_scoring.py
-    - python3 offline_profile.py --model facebook/opt-125m
+    - python3 offline_profile.py --model facebook/opt-125m run_num_steps --num-steps 2
 
 - label: Prefix Caching Test # 9min
   mirror_hardwares: [amd]
@@ -4,9 +4,10 @@ import os
 import sys
 from argparse import RawTextHelpFormatter
 from dataclasses import asdict, dataclass
-from typing import Optional
+from typing import Any, Dict, Generator, List, Optional, TypeAlias
 
 import torch
+import tqdm
 
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
@@ -15,16 +16,21 @@ from vllm.utils import FlexibleArgumentParser
 
 BATCH_SIZE_DEFAULT = 1
 PROMPT_LEN_DEFAULT = 256
-OUTPUT_LEN_DEFAULT = 2
 
 
 @dataclass
 class ProfileContext:
     engine_args: EngineArgs
     prompt_len: int
-    output_len: int
     batch_size: int
-    save_chrome_traces_folder: Optional[str]
+
+    # The profiler can run in 2 modes,
+    # 1. Run profiler for user specified num_steps
+    num_steps: Optional[int] = None
+    # 2. Run profiler until all requests complete
+    complete_num_requests_per_step: Optional[int] = None
+
+    save_chrome_traces_folder: Optional[str] = None
 
 
 def get_dtype(dtype: str):
@@ -34,23 +40,155 @@ def get_dtype(dtype: str):
     return dtype
 
 
+OutputLen_NumReqs_Map: TypeAlias = Dict[int, int]
+def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \
+        -> OutputLen_NumReqs_Map:
+    """
+    Given the number of requests, batch_size, and the number of requests
+    that each engine-step should process, step_requests, determine the
+    output lengths of the requests such that step_request is honoured.
+
+    Example:
+    if batch size = 128 and step_request = [128, 128, 96, 64, 32, 1]
+    then return,
+    {2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning,
+    32 requests should have output length 2,
+    32 requests should have output length 3,
+    32 requests should have output length 4,
+    31 requests should have output length 5,
+    1 request should have output length 6.
+
+    Args:
+        batch_size (int): Number of requests submitted for profile. This is
+            args.batch_size.
+        step_requests (List[int]): step_requests[i] is the number of requests
+            that the ith engine step should process.
+
+    Returns:
+        OutputLen_NumReqs_Map : A dictionary with output-length as keys and the
+            number of requests required to have that output-length as values.
+    """
+    ol_nr: OutputLen_NumReqs_Map = {}
+
+    # Number of requests that are assigned an output-length
+    num_reqs_assigned: int = 0
+    num_steps: int = len(step_requests)
+
+    # sanity check. The first step (prefill-step), must process all requests.
+    assert step_requests[0] == batch_size
+
+    # Begin assignments from the last step.
+    output_length: int = num_steps
+    for num_requests_at_step in reversed(step_requests):
+        if num_reqs_assigned == batch_size:
+            break
+
+        assert num_reqs_assigned < batch_size
+
+        # Remove the number of requests that have been determined
+        # to participate in this step and beyond.
+        num_reqs_unassigned_at_step = num_requests_at_step - num_reqs_assigned
+        assert num_reqs_unassigned_at_step >= 0
+
+        if num_reqs_unassigned_at_step > 0:
+            ol_nr[output_length] = num_reqs_unassigned_at_step
+            num_reqs_assigned += num_reqs_unassigned_at_step
+
+        output_length -= 1
+
+    # sanity checks.
+    assert sum(ol_nr.values()) == batch_size, \
+        ("Number of requests in output-length assignment does not match "
+         f"batch-size.\n batch size {batch_size} - "
+         f"step requests {step_requests} - assignments {ol_nr}")
+
+    # Check that the output-length is in [1, num-steps]. Output length must be
+    # at least 1 as all requests must participate in the prefill-step.
+    assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), \
+        ("Output lengths of requests should be in range "
+         f"[1, num-engine-steps].\n batch size {batch_size} - "
+         f"step requests {step_requests} - assignments {ol_nr}")
+
+    return ol_nr
+
+
+def determine_requests_per_step(context: ProfileContext) -> List[int]:
+    """
+    Determine number of requests each engine step should process.
+    If context.num_steps is set, then all engine steps process the
+    same number of requests and the output list is of length
+    context.num_steps.
+
+    If context.complete_num_requests_per_step is set, then each decode step
+    processes fewer and fewer requests until there are no requests to process.
+    In this case, the output list is as big as the number of steps
+    required to process all requests.
+
+    Args:
+        context: ProfileContext object.
+
+    Returns:
+        List[int]: Number of requests to process for all engine-steps.
+            output[i], contains the number of requests that the ith step
+            should process.
+    """
+    if context.num_steps:
+        # All requests must run until num_engine_steps. This implies
+        # that their output lengths must be equal to num_engine_steps.
+        return [context.batch_size] * context.num_steps
+
+    assert context.complete_num_requests_per_step and \
+        context.complete_num_requests_per_step > 0, \
+        (f"Expected a positive complete_num_requests_per_step argument."
+         f"Instead got {context.complete_num_requests_per_step}")
+
+    # We start dropping after the first decode step.
+    step_requests = [
+        context.batch_size,  # prefill
+        context.batch_size,  # decode
+    ]
+
+    num_running_requests = context.batch_size
+    num_running_requests -= context.complete_num_requests_per_step
+    while num_running_requests > 0:
+        step_requests.append(num_running_requests)
+        num_running_requests -= context.complete_num_requests_per_step
+
+    if step_requests[-1] != 1:
+        # have 1 request running at the last step. This is often
+        # useful
+        step_requests.append(1)
+
+    return step_requests
+
+
 def run_profile(context: ProfileContext, csv_output: Optional[str],
                 json_output: Optional[str]):
     print("Run profile with:")
     for key, value in asdict(context).items():
         print(f"  {key} = {value}")
 
+    requests_per_step: List[int] = determine_requests_per_step(context)
+
+    ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
+        context.batch_size, requests_per_step)
+
+    num_steps_to_profile: int = len(requests_per_step)
+    max_output_len: int = max(ol_nr.keys())
+    assert max_output_len >= 1
+
     # Create sampling params
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=args.output_len,
-                                     ignore_eos=True)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        # max_tokens is set on a per-request basis.
+        max_tokens=None,
+        ignore_eos=True)
 
     # Create LLM
     llm = LLM(**asdict(context.engine_args))
     batch_size = context.batch_size
     prompt_len = context.prompt_len
-    output_len = context.output_len
 
     scheduler_config = llm.llm_engine.scheduler_config
     max_model_len = llm.llm_engine.model_config.max_model_len
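As a cross-check of the scheduling math introduced above, the following standalone sketch (not part of the diff; simplified function names) reproduces the docstring example: a batch of 128 requests that completes 32 requests per decode step yields the step schedule [128, 128, 96, 64, 32, 1], which maps back to the output-length assignment {2: 32, 3: 32, 4: 32, 5: 31, 6: 1}.

```python
# Hypothetical standalone check of the scheduling logic sketched above; the
# real implementations live in the profiler script shown in this diff.
from typing import Dict, List


def requests_per_step(batch_size: int, completed_per_step: int) -> List[int]:
    # Prefill step plus a first full decode step, then drop requests.
    steps = [batch_size, batch_size]
    running = batch_size - completed_per_step
    while running > 0:
        steps.append(running)
        running -= completed_per_step
    if steps[-1] != 1:
        steps.append(1)  # tack on a one-request step, as the commit does
    return steps


def output_lengths(batch_size: int, steps: List[int]) -> Dict[int, int]:
    # Walk the schedule backwards, assigning the longest output lengths first.
    ol_nr: Dict[int, int] = {}
    assigned = 0
    length = len(steps)
    for reqs_at_step in reversed(steps):
        newly_assigned = reqs_at_step - assigned
        if newly_assigned > 0:
            ol_nr[length] = newly_assigned
            assigned += newly_assigned
        length -= 1
    return ol_nr


steps = requests_per_step(128, 32)
print(steps)                       # [128, 128, 96, 64, 32, 1]
print(output_lengths(128, steps))  # {6: 1, 5: 31, 4: 32, 3: 32, 2: 32}
```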
@@ -65,7 +203,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
             f"choose a smaller batch size or prompt length, or increase "
             f"--max-num-batched-tokens")
         sys.exit(-1)
-    if batch_size >= max_num_seqs:
+    if batch_size > max_num_seqs:
         print(
             f"ERROR: chosen batch_size ({batch_size}) is larger than "
             f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
@@ -73,16 +211,26 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
         sys.exit(-1)
     print("llm.llm_engine.model_config.max_model_len: ",
           llm.llm_engine.model_config.max_model_len)
-    if prompt_len + output_len > llm.llm_engine.model_config.max_model_len:
-        print(
-            f"ERROR: chosen prompt_len + output_len ({prompt_len} + "
-            f"{output_len} = {prompt_len + output_len}) is larger than the "
-            f"model's max_model_len ({max_model_len}), please choose a smaller "
-            f"prompt_len or output_len, or increase --max-model-len")
+    if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len:
+        print(f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + "
+              f"{max_output_len} = {prompt_len + max_output_len}) is larger "
+              f"than the model's max_model_len ({max_model_len}), please "
+              f"choose a smaller prompt_len or max_output_len, or increase "
+              f"--max-model-len")
         sys.exit(-1)
 
     def add_requests():
+
+        def get_output_len_generator() -> Generator[int, Any, Any]:
+            for output_len, num_reqs in ol_nr.items():
+                for _ in range(num_reqs):
+                    yield output_len
+
+        output_len_generator = get_output_len_generator()
         for i in range(batch_size):
+            sampling_params.max_tokens = next(output_len_generator)
+            assert isinstance(sampling_params.max_tokens, int)
+
             prompt_token_ids = torch.randint(
                 llm.llm_engine.model_config.get_vocab_size(),
                 size=(prompt_len, )).tolist()
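The nested generator added to add_requests flattens the output-length map into one max_tokens value per request. A minimal standalone sketch of that flattening, reusing the docstring's example numbers:

```python
# Standalone sketch: flatten an output-length -> request-count map into one
# max_tokens value per request, as the generator inside add_requests does.
from typing import Any, Dict, Generator


def output_len_generator(ol_nr: Dict[int, int]) -> Generator[int, Any, Any]:
    for output_len, num_reqs in ol_nr.items():
        for _ in range(num_reqs):
            yield output_len


lens = list(output_len_generator({2: 32, 3: 32, 4: 32, 5: 31, 6: 1}))
print(len(lens), lens[0], lens[-1])  # 128 2 6
```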
@@ -110,8 +258,11 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
         llm.llm_engine.step()  # First step is prefill
 
     decode_profs = []
-    for x in range(args.output_len - 1):
-        with layerwise_profile() as decode_prof:
+    for _ in tqdm.tqdm(range(num_steps_to_profile - 1)):
+        num_running_seqs = llm.llm_engine.scheduler[
+            0].get_num_unfinished_seq_groups()
+        with layerwise_profile(
+                num_running_seqs=num_running_seqs) as decode_prof:
             llm.llm_engine.step()
         decode_profs.append(decode_prof)
 
@@ -154,7 +305,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
     decode_results_list[0].print_summary_table()
 
     if csv_output:
-        csv_filename_base = csv_output.rstrip(".csv")
+        csv_filename_base = csv_output[:-4] \
+            if csv_output.endswith('.csv') else csv_output
         prefill_results.export_model_stats_table_csv(
             csv_filename_base + "_prefill_model_table.csv")
         prefill_results.export_summary_stats_table_csv(
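The old `rstrip(".csv")` call strips any trailing run of the characters '.', 'c', 's', 'v' rather than the literal suffix, which is why it is replaced with an explicit suffix check. A quick illustration (plain Python, not from the diff):

```python
# str.rstrip removes a set of characters, not a suffix: this is why the old
# csv_filename_base computation could mangle output file names.
print("results.csv".rstrip(".csv"))      # 'result'   (drops the trailing 's' too)
print("profile_csv.csv".rstrip(".csv"))  # 'profile_' (removes far more than '.csv')

# The replacement keeps everything before a literal '.csv' suffix:
csv_output = "profile_csv.csv"
base = csv_output[:-4] if csv_output.endswith('.csv') else csv_output
print(base)                              # 'profile_csv'
```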
@@ -187,10 +339,10 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
         for idx, dr in enumerate(decode_results_list):
             json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()
 
-        for idx, dr in enumerate(decode_results_list[1:]):
-            json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()
-
-        with open(json_output.rstrip(".json") + ".json", "w+") as f:
+        # Add .json to json_output filename if it doesn't exist already.
+        json_output_file = json_output if json_output.endswith(
+            '.json') else json_output + '.json'
+        with open(json_output_file, "w+") as f:
             json.dump(json_dict, f, indent=2)
     pass
 
@@ -214,7 +366,7 @@ Profile a model
     python examples/offline_profile.py \\
         --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\
         --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\
-        --enforce-eager
+        --enforce-eager run_num_steps -n 2
     ```
 
     then you can use various tools to analyze the json output
@@ -261,17 +413,41 @@ Profile a model
         default=BATCH_SIZE_DEFAULT,
         help=f"Number of requests to run as a single batch, "
         f"default={BATCH_SIZE_DEFAULT}")
-    parser.add_argument(
-        "--output-len",
-        type=int,
-        default=OUTPUT_LEN_DEFAULT,
-        help="Number of llm steps to run (includes prefill and decode) "
-        "- default={OUTPUT_LEN_DEFAULT}")
+
+    subparsers = parser.add_subparsers(dest="cmd")
+
+    run_num_steps_parser = subparsers.add_parser(
+        "run_num_steps",
+        help="This variation profiles n engine.step() invocations.")
+    run_num_steps_parser.add_argument(
+        '-n',
+        '--num-steps',
+        type=int,
+        help="Number of engine steps to profile.\n"
+        "Setting it to 1, profiles only the prefill step.\n"
+        "Setting it to 2, profiles the prefill and first decode step\n"
+        "Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n"
+        "and so on ...")
+
+    run_to_completion_parser = subparsers.add_parser(
+        "run_to_completion",
+        help="This variation profiles all the engine.step() invocations"
+        "until the engine exhausts all submitted requests.")
+    run_to_completion_parser.add_argument(
+        '-n',
+        '--complete-num-requests-per-step',
+        type=int,
+        help=
+        "Complete complete_num_requests_per_step requests every decode step."
+        "For e.g., with batch_size 128 and complete_num_requests_per_step 32,"
+        "the profiler is run for 6 engine steps, with the steps processing, "
+        "128, 128, 96, 64, 32, 1 requests respectively.\n"
+        "Note that we tack-on a one-request step at the end as it is often "
+        "useful.")
 
     EngineArgs.add_cli_args(parser)
 
     args = parser.parse_args()
 
     context = ProfileContext(
         engine_args=EngineArgs.from_cli_args(args),
         **{
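For reference, a minimal self-contained sketch (not the commit's exact wiring; defaults here are assumptions) of how the two subcommands above parse, with the `dest="cmd"` value selecting between the num_steps and complete_num_requests_per_step fields of ProfileContext:

```python
# Minimal sketch of the run_num_steps / run_to_completion subcommands.
import argparse

parser = argparse.ArgumentParser("offline profile sketch")
subparsers = parser.add_subparsers(dest="cmd")

num_steps_p = subparsers.add_parser("run_num_steps")
num_steps_p.add_argument("-n", "--num-steps", type=int, default=2)

completion_p = subparsers.add_parser("run_to_completion")
completion_p.add_argument("-n", "--complete-num-requests-per-step", type=int)

args = parser.parse_args(["run_to_completion", "-n", "32"])
print(args.cmd, args.complete_num_requests_per_step)  # run_to_completion 32
```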
@@ -34,9 +34,10 @@ if __name__ == "__main__":
                         "examples/offline_profile.py")
     parser.add_argument("--phase",
                         type=str,
-                        choices=["prefill", "decode_1"],
                         required=True,
-                        help="The phase to print the table for.")
+                        help="The phase to print the table for. This is either"
+                        "prefill or decode_n, where n is the decode step "
+                        "number")
     parser.add_argument("--table",
                         type=str,
                         choices=["summary", "model"],
@@ -49,6 +50,10 @@ if __name__ == "__main__":
     with open(args.json_trace) as f:
         profile_data = json.load(f)
 
+    assert args.phase in profile_data, \
+        (f"Cannot find phase {args.phase} in profile data. Choose one among"
+         f'{[x for x in profile_data.keys() if "prefill" in x or "decode" in x]}') #noqa
+
     if args.table == "summary":
         entries_and_depths = flatten_entries(
             SummaryStatsEntry, profile_data[args.phase]["summary_stats"])
@@ -151,16 +151,31 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
             "scaled_int8_quant" in op_name:
             return True
 
+    # LoRA ops
+    def is_sgmv_shrink(op_name: str):
+        return "sgmv_shrink" in op_name
+
+    def is_sgmv_expand(op_name: str):
+        return "sgmv_expand" in op_name
+
+    def is_bgmv_shrink(op_name: str):
+        return "bgmv_shrink" in op_name
+
+    def is_bgmv_expand(op_name: str):
+        return "bgmv_expand" in op_name
+
+    def is_cutlass_gemm_op(op_name: str):
+        return "void cutlass::Kernel" in op_name or \
+            "void cutlass::device_kernel" in op_name
+
     def is_gemm_op(op_name: str):
         if is_quant(op_name):
             return False
-        if "xmma_gemm" in op_name or \
+        return is_cutlass_gemm_op(op_name) or \
+            "xmma_gemm" in op_name or \
             "gemv2T_kernel" in op_name or \
             "splitKreduce" in op_name or \
-            "void cutlass::Kernel" in op_name or \
-            "void cutlass::device_kernel" in op_name or \
-            "s16816gemm" in op_name:
-            return True
+            "s16816gemm" in op_name
 
     def is_elementwise_op(op_name: str):
         return "elementwise_kernel" in op_name
@@ -211,6 +226,18 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
     quant_ops = list(filter(lambda x: is_quant(x), ops))
     ops = list(filter(lambda x: x not in quant_ops, ops))
 
+    sgmv_shrink_ops = list(filter(lambda x: is_sgmv_shrink(x), ops))
+    ops = list(filter(lambda x: x not in sgmv_shrink_ops, ops))
+    sgmv_expand_ops = list(filter(lambda x: is_sgmv_expand(x), ops))
+    ops = list(filter(lambda x: x not in sgmv_expand_ops, ops))
+    bgmv_shrink_ops = list(filter(lambda x: is_bgmv_shrink(x), ops))
+    ops = list(filter(lambda x: x not in bgmv_shrink_ops, ops))
+    bgmv_expand_ops = list(filter(lambda x: is_bgmv_expand(x), ops))
+    ops = list(filter(lambda x: x not in bgmv_expand_ops, ops))
+
+    cutlass_gemm_ops = list(filter(lambda x: is_cutlass_gemm_op(x), ops))
+    ops = list(filter(lambda x: x not in cutlass_gemm_ops, ops))
+
     gemm_ops = list(filter(lambda x: is_gemm_op(x), ops))
     ops = list(filter(lambda x: x not in gemm_ops, ops))
 
@@ -257,6 +284,24 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
         trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1)
     if len(quant_ops):
         trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1)
+
+    if len(sgmv_shrink_ops):
+        trace_df['sgmv_shrink_ops'] = trace_df[sgmv_shrink_ops].agg("sum",
+                                                                    axis=1)
+    if len(sgmv_expand_ops):
+        trace_df['sgmv_expand_ops'] = trace_df[sgmv_expand_ops].agg("sum",
+                                                                    axis=1)
+    if len(bgmv_shrink_ops):
+        trace_df['bgmv_shrink_ops'] = trace_df[bgmv_shrink_ops].agg("sum",
+                                                                    axis=1)
+    if len(bgmv_expand_ops):
+        trace_df['bgmv_expand_ops'] = trace_df[bgmv_expand_ops].agg("sum",
+                                                                    axis=1)
+
+    if len(cutlass_gemm_ops):
+        trace_df['cutlass_gemm_ops'] = trace_df[cutlass_gemm_ops].agg("sum",
+                                                                      axis=1)
+
     if len(gemm_ops):
         trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1)
     if len(rms_norm_ops):
@@ -296,7 +341,9 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
         trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum",
                                                                         axis=1)
 
-    trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops +
+    trace_df.drop(attention_ops + quant_ops + sgmv_shrink_ops +
+                  sgmv_expand_ops + bgmv_shrink_ops + bgmv_expand_ops +
+                  cutlass_gemm_ops + gemm_ops + rms_norm_ops +
                   vocab_embed_ops + mem_ops + elementwise_ops +
                   nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops +
                   nccl_other_ops + cross_device_reduce_1stage_ops +
@@ -315,7 +362,14 @@ def plot_trace_df(traces_df: pd.DataFrame,
                   plot_title: str,
                   output: Optional[Path] = None):
 
+    def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str:
+        phase_df = traces_df.query(f'phase == "{phase}"')
+        descs = phase_df['phase_desc'].to_list()
+        assert all([desc == descs[0] for desc in descs])
+        return descs[0]
+
     phases = traces_df['phase'].unique()
+    phase_descs = [get_phase_description(traces_df, p) for p in phases]
     traces_df = traces_df.pivot_table(index="phase",
                                       columns="name",
                                       values=plot_metric,
@@ -324,7 +378,8 @@ def plot_trace_df(traces_df: pd.DataFrame,
     traces_df = group_trace_by_operations(traces_df)
 
     # Make the figure
-    fig, ax = plt.subplots(1, figsize=(5, 8), sharex=True)
+    fig_size_x = max(5, len(phases))
+    fig, ax = plt.subplots(1, figsize=(fig_size_x, 8), sharex=True)
 
     # Draw the stacked bars
     ops = list(traces_df)
@@ -332,7 +387,7 @@ def plot_trace_df(traces_df: pd.DataFrame,
     for op in ops:
         values = [traces_df[op][phase] for phase in phases]
         values = list(map(lambda x: 0.0 if math.isnan(x) else x, values))
-        ax.bar(phases, values, label=op, bottom=bottom)
+        ax.bar(phase_descs, values, label=op, bottom=bottom)
         bottom = [bottom[j] + values[j] for j in range(len(phases))]
 
     # Write the values as text on the bars
@@ -390,6 +445,14 @@ def main(
                 ["name"]] = "others"
         return df
 
+    def get_phase_description(key: str) -> str:
+        num_running_seqs = profile_json[key]['metadata'][
+            'num_running_seqs']
+        if num_running_seqs is not None:
+            return f"{key}-seqs-{num_running_seqs}"
+        else:
+            return key
+
     # Get data for each key
     traces = list(map(lambda x: get_entries_and_traces(x), step_keys))
 
@@ -413,6 +476,7 @@ def main(
     # Fill in information about the step-keys
     for trace_df, step_key in zip(trace_dfs, step_keys):
         trace_df['phase'] = step_key
+        trace_df['phase_desc'] = get_phase_description(step_key)
 
     # Combine all data frames so they can be put in a single plot
     traces_df = pd.concat(trace_dfs)
@@ -426,12 +490,16 @@ def main(
     def make_plot_title_suffix(profile_json: dict) -> str:
         context = profile_json["context"]
         sparsity = context.get('sparsity', None)
-        return (f"{context['model']}\n"
+        run_type = \
+            f'Run {context["num_steps"]} steps' if context['num_steps'] else \
+            (f'Complete {context["complete_num_requests_per_step"]} per '
+             f'step; Run till completion')
+        return (f"{context['engine_args']['model']}\n"
                 f"Batch={context['batch_size']}, "
                 f"PromptLen={context['prompt_len']}, "
-                f"OutputLen={context['output_len']},"
-                f"NumGpus={context['tensor_parallel_size']}"
-                f"{', Sparsity ' + sparsity if sparsity else ''}")
+                f"NumGpus={context['engine_args']['tensor_parallel_size']}"
+                f"{', Sparsity ' + sparsity if sparsity else ''}\n"
+                f"Run Type: {run_type}")
 
     profile_json = None
     with open(json_trace) as f:
@@ -72,6 +72,9 @@ class LayerwiseProfileResults(profile):
     _model_stats_tree: List[_StatsTreeNode] = field(init=False)
     _summary_stats_tree: List[_StatsTreeNode] = field(init=False)
 
+    # profile metadata
+    num_running_seqs: Optional[int] = None
+
     def __post_init__(self):
         self._build_correlation_map()
         self._build_module_tree()
@@ -127,6 +130,9 @@ class LayerwiseProfileResults(profile):
 
     def convert_stats_to_dict(self) -> str:
         return {
+            "metadata": {
+                "num_running_seqs": self.num_running_seqs
+            },
             "summary_stats":
            self._convert_stats_tree_to_dict(self._summary_stats_tree),
             "model_stats":
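With this change, each phase entry in the exported JSON gains a small metadata block alongside the existing stats trees, and the visualizer's phase labels are derived from it. A rough sketch of the assumed shape (stats trees elided, values hypothetical):

```python
# Assumed shape of one decode entry in the profiler JSON after this change;
# only the "metadata" block is new, the stats trees are elided here.
example_decode_entry = {
    "metadata": {"num_running_seqs": 128},
    "summary_stats": {},  # per-module summary tree (elided)
    "model_stats": {},    # per-op model tree (elided)
}

# Phase labels in the visualizer are derived from this metadata, roughly:
key = "decode_1"
num_running_seqs = example_decode_entry["metadata"]["num_running_seqs"]
phase_desc = f"{key}-seqs-{num_running_seqs}" if num_running_seqs else key
print(phase_desc)  # decode_1-seqs-128
```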
@@ -338,7 +344,15 @@ class LayerwiseProfileResults(profile):
 
 class layerwise_profile(profile):
 
-    def __init__(self):
+    def __init__(self, num_running_seqs: Optional[int] = None):
+        """
+        layerwise profile constructor.
+
+        Args:
+            num_running_seqs (Optional[int], optional): When given,
+            num_running_seqs will be passed to LayerProfileResults for metadata
+            update. Defaults to None.
+        """
         super().__init__(
             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=True,
@@ -346,9 +360,13 @@ class layerwise_profile(profile):
             with_modules=True,
             experimental_config=_ExperimentalConfig(verbose=True))
 
+        self.num_running_seqs = num_running_seqs
+
     def __enter__(self):
         return super().__enter__()
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         super().__exit__(exc_type, exc_val, exc_tb)
-        self.results = LayerwiseProfileResults(self.profiler.kineto_results)
+        self.results = LayerwiseProfileResults(
+            self.profiler.kineto_results,
+            num_running_seqs=self.num_running_seqs)
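Taken together with the decode-loop changes earlier in the diff, the intended call pattern looks roughly like the sketch below. It is illustrative only: constructing the LLM is elided, and the scheduler accessor and import path mirror those used in the diff above.

```python
# Sketch only: assumes `llm` is an already-constructed vllm.LLM instance.
from vllm.profiler import layerwise_profile

num_running_seqs = llm.llm_engine.scheduler[0].get_num_unfinished_seq_groups()
with layerwise_profile(num_running_seqs=num_running_seqs) as decode_prof:
    llm.llm_engine.step()

stats = decode_prof.results.convert_stats_to_dict()
print(stats["metadata"]["num_running_seqs"])  # sequences in flight this step
```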