# SPDX-License-Identifier: Apache-2.0
import argparse
import copy
import itertools
import math
import os
import pickle as pkl
import time
from collections.abc import Iterable
from dataclasses import dataclass
from itertools import product
from typing import Callable, Optional

import pandas as pd
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales,
    marlin_zero_points)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_rows, quantize_weights)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
DEFAULT_TP_SIZES = [1]

NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False)

if NVTX_PROFILE:
    import nvtx


def terse_type_name(dt):
    return {
        torch.bfloat16: "bf16",
        torch.float16: "fp16",
        torch.int8: "int8",
        torch.float8_e4m3fn: "fp8",
        torch.float: "float",
        torch.int: "int",
    }[dt]


@dataclass
class BenchmarkTensors:
    w_ref: torch.Tensor
    a: torch.Tensor

    w_q: torch.Tensor
    group_size: Optional[int]
    wtype: ScalarType
    w_g_s: torch.Tensor
    w_g_zp: Optional[torch.Tensor]
    w_ch_s: Optional[torch.Tensor]
    w_tok_s: Optional[torch.Tensor]


@dataclass
class TypeConfig:
    act_type: torch.dtype
    weight_type: ScalarType
    output_type: Optional[torch.dtype]
    group_scale_type: Optional[torch.dtype]
    group_zero_type: Optional[torch.dtype]
    channel_scale_type: Optional[torch.dtype]
    token_scale_type: Optional[torch.dtype]


def rand_data(shape, dtype=torch.float16, scale=1):
    if dtype.is_floating_point:
        return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype)
    else:
        return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")


def quantize_and_pack(atype: torch.dtype,
                      w: torch.Tensor,
                      wtype: ScalarType,
                      stype: Optional[torch.dtype],
                      group_size: Optional[int],
                      zero_points: bool = False):
    assert wtype.is_integer(), "TODO: support floating point weights"

    w_ref, w_q, w_s, w_zp = quantize_weights(
        w,
        wtype,
        group_size=group_size,
        zero_points=zero_points,
        # to match how the kernel applies zps
        ref_zero_points_after_scales=True)

    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
    return w_ref, w_q, w_s, w_zp
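
# Illustrative sketch (not used by the benchmark) of the row-wise packing
# that pack_rows performs for 4-bit types: eight 4-bit values from
# consecutive rows are packed into one 32-bit word, so a (k, n) tensor
# becomes (k // 8, n). The helper name below is hypothetical; pack_rows
# itself handles arbitrary bit widths.
def _pack_rows_4bit_sketch(q: torch.Tensor) -> torch.Tensor:
    k, n = q.shape  # q holds integer values in [0, 16)
    assert k % 8 == 0
    out = torch.zeros(k // 8, n, dtype=torch.int32, device=q.device)
    for i in range(8):
        # place the 4-bit value from every 8th row into its nibble slot
        out |= (q[i::8, :].to(torch.int32) & 0xF) << (4 * i)
    return out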
def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
                         group_size: Optional[int]) -> list[BenchmarkTensors]:
    m, n, k = shape

    # we want to make sure that weights don't fit into L2 cache between runs
    # so we construct enough weights to exceed L2 cache, which is 50MB on an
    # H100, so we target a total weight size > 2 * 50MB
    num_weights = math.ceil(2 * 50 * 1024**2 * 8 /
                            (k * n * types.weight_type.size_bits))

    a = rand_data((m, k), types.act_type, scale=5)

    benchmark_tensors: list[BenchmarkTensors] = []
    for _ in range(num_weights):
        w = rand_data((k, n), types.act_type, scale=5)

        if types.group_scale_type is not None:
            w = w.to(types.group_scale_type)
        if w.dtype.itemsize == 1:
            w = w.to(torch.float16)

        w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
            a.dtype, w, types.weight_type, types.group_scale_type, group_size,
            types.group_zero_type is not None)

        if not a.dtype.is_floating_point:
            aiinfo = torch.iinfo(a.dtype)
            w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)

        w_ref = w_ref.to(torch.float32)

        w_ch_s = None if types.channel_scale_type is None else\
            rand_data((n,), types.channel_scale_type)
        w_tok_s = None if types.token_scale_type is None else\
            rand_data((m,), types.token_scale_type)

        benchmark_tensors.append(
            BenchmarkTensors(w_ref=w_ref,
                             a=a,
                             w_q=w_q_packed,
                             wtype=types.weight_type,
                             w_g_s=w_s,
                             w_g_zp=w_zp,
                             group_size=group_size,
                             w_ch_s=w_ch_s,
                             w_tok_s=w_tok_s))

    return benchmark_tensors


def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable:
    a = bt.a
    w = bt.w_ref.to(bt.a.dtype)  # use float reference tensor
    if a.dtype not in [torch.float16, torch.bfloat16]:
        a = a.to(torch.float16)
        w = w.to(torch.float16)
    return lambda: torch.matmul(a, w)


def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable:
    if bt.w_ch_s is not None and bt.w_tok_s is not None:
        scale_a = bt.w_tok_s.to(torch.float32)
        scale_b = bt.w_ch_s.to(torch.float32)
    else:
        scale_a = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
        scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
    w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t()
    return lambda: ops.cutlass_scaled_mm(
        bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16)


def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
    device = bt.a.device

    workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)

    # repack the quantized weights, scales, and zero points into the layouts
    # the Marlin kernels expect
    if bt.w_g_zp is None:
        w_zp = torch.empty(0, dtype=torch.int, device=device)
    else:
        w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0],
                                  bt.w_ref.shape[1], bt.wtype.size_bits)

    if bt.group_size is None:
        w_s = torch.tensor([], device=device, dtype=torch.half)
    else:
        w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0],
                                    bt.w_ref.shape[1], bt.group_size)

    sort_indices = torch.empty(0, dtype=torch.int, device=device)
    g_idx = torch.empty(0, dtype=torch.int, device=device)
    w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0],
                                 bt.w_ref.shape[1], bt.wtype.size_bits)

    if bt.a.dtype.is_floating_point:
        assert bt.w_ch_s is None
        assert bt.w_tok_s is None
        assert bt.group_size is not None

        fn = lambda: ops.gptq_marlin_gemm(a=bt.a,
                                          b_q_weight=w_q,
                                          b_scales=w_s,
                                          b_zeros=w_zp,
                                          g_idx=g_idx,
                                          perm=sort_indices,
                                          workspace=workspace.scratch,
                                          b_q_type=bt.wtype,
                                          size_m=bt.a.shape[0],
                                          size_n=bt.w_ref.shape[1],
                                          size_k=bt.w_ref.shape[0],
                                          is_k_full=True,
                                          is_zp_float=False)
    else:
        assert bt.a.dtype == torch.int8
        assert bt.wtype == scalar_types.uint4b8

        if bt.w_ch_s is not None:
            s_ch = bt.w_ch_s.to(torch.float32)
        else:
            s_ch = torch.ones(bt.w_ref.shape[1],
                              dtype=torch.float32,
                              device=device)

        if bt.w_tok_s is not None:
            s_tok = bt.w_tok_s.to(torch.float32)
        else:
            s_tok = torch.ones(bt.a.shape[0],
                               dtype=torch.float32,
                               device=device)

        fn = lambda: ops.marlin_qqq_gemm(a=bt.a,
                                         b_q_weight=w_q,
                                         s_group=w_s,
                                         s_tok=s_tok,
                                         s_ch=s_ch,
                                         workspace=workspace.scratch,
                                         size_m=bt.a.shape[0],
                                         size_n=bt.w_ref.shape[1],
                                         size_k=bt.w_ref.shape[0])

    return fn


def machete_create_bench_fn(bt: BenchmarkTensors,
                            out_type: Optional[torch.dtype] = None,
                            schedule=None) -> Callable:
    w_q = bt.w_q.t().contiguous().t()  # make col major
    w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype,
                                None if bt.w_g_s is None else bt.w_g_s.dtype)

    w_g_zp = bt.w_g_zp
    if w_g_zp is not None:
        # fold the group scales into the zero points: the kernel consumes
        # -scale * zero_point (see the sketch below)
        w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype))

    return lambda: ops.machete_mm(
        a=bt.a,
        b_q=w_q,
        b_type=bt.wtype,
        b_group_scales=bt.w_g_s,
        b_group_zeros=w_g_zp,
        b_group_size=bt.group_size,
        b_channel_scales=bt.w_ch_s,
        a_token_scales=bt.w_tok_s,
        out_type=out_type,
        schedule=schedule,
    )
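
# Why -scale * zero_point above: with standard asymmetric dequantization,
#   dequant(q) = s * (q - zp) = s * q + (-s * zp),
# so pre-multiplying the zero points by the group scales lets the kernel
# apply them as a purely additive term. A minimal sanity check of that
# algebra (illustrative only, not used by the benchmark):
def _zp_folding_sanity_check():
    s = torch.rand(8) + 0.5  # group scales
    zp = torch.randint(0, 16, (8, )).to(torch.float)  # integer zero points
    q = torch.randint(0, 16, (8, )).to(torch.float)  # quantized values
    assert torch.allclose(s * (q - zp), s * q + (-s * zp))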
# impl


# bench
def bench_fns(label: str, sub_label: str, description: str,
              fns: list[Callable]):

    min_run_time = 1 if not NVTX_PROFILE else 0.1
    res = TBenchmark.Timer(
        stmt="""
        for fn in fns:
            fn()
        """,
        globals={
            "fns": fns
        },
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)

    if NVTX_PROFILE:
        with nvtx.annotate("mm-bench"), nvtx.annotate(
                f"{label}|{sub_label}|{description}"):
            fns[0]()

    return res


_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None


def bench(types: TypeConfig,
          group_size: Optional[int],
          m: int,
          k: int,
          n: int,
          label: str,
          sub_label: str,
          sweep_schedules: bool = True) -> list[TMeasurement]:
    benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
    sub_label += f", L={len(benchmark_tensors)}"

    name_type_string = f"W{types.weight_type}"+\
                       f"-A{terse_type_name(types.act_type)}"
    if types.group_scale_type is not None:
        name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
    if types.group_zero_type is not None:
        name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}"
    if group_size is not None:
        name_type_string += f"-G{group_size}"
    if types.channel_scale_type is not None:
        name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}"
    if types.token_scale_type is not None:
        name_type_string += f"-TS{terse_type_name(types.token_scale_type)}"

    timers = []
    # pytorch impl
    timers.append(
        bench_fns(
            label, sub_label, "torch.matmul (fp16)",
            [torch_matmul_f16_create_bench_fn(bt)
             for bt in benchmark_tensors]))

    if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
        timers.append(
            bench_fns(
                label, sub_label,
                f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [
                    cutlass_scaled_mm_create_bench_fn(bt)
                    for bt in benchmark_tensors
                ]))

    if types.act_type != torch.float8_e4m3fn:
        timers.append(
            bench_fns(label, sub_label, f"marlin ({name_type_string})",
                      [marlin_create_bench_fn(bt)
                       for bt in benchmark_tensors]))

    # machete
    timers.append(
        bench_fns(label, sub_label, f"machete ({name_type_string})", [
            machete_create_bench_fn(bt, out_type=types.output_type)
            for bt in benchmark_tensors
        ]))

    if sweep_schedules:
        global _SWEEP_SCHEDULES_RESULTS

        print("Finding best schedule for machete")
        best = None
        best_schedule = None
        schedules = ops.machete_supported_schedules(
            a_type=types.act_type,
            b_type=types.weight_type,
            group_scales_type=types.group_scale_type,
            group_zeros_type=types.group_zero_type,
            token_scales_type=types.token_scale_type,
            channel_scales_type=types.channel_scale_type,
            out_type=types.output_type)

        if schedules is None or len(schedules) == 0:
            raise ValueError("No schedules found to sweep")

        for schedule in reversed(schedules):
            # schedule names encode the tile shape in their first
            # "_"-separated token (e.g. "128x16_..."); the second
            # "x"-separated field is taken as the tile M size
            schedule_M = int(schedule.split("_")[0].split("x")[1])

            # Prune known bad schedules
            if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
                continue

            res = bench_fns(label, sub_label, "machete_best", [
                machete_create_bench_fn(
                    bt, out_type=types.output_type, schedule=schedule)
                for bt in benchmark_tensors
            ])

            results_row = {
                "M": m,
                "K": k,
                "N": n,
                "group_size": group_size,
                "schedule": schedule,
                "median": res.median,
            }
            if _SWEEP_SCHEDULES_RESULTS is None:
                _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(
                    columns=results_row.keys())
            _SWEEP_SCHEDULES_RESULTS.\
                loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row

            print(f" {res.median:5.5} ", schedule)

            if not best or res.median < best.median:
                best = res
                best_schedule = schedule
        print("Best schedule:", best_schedule)
        timers.append(best)
    return timers
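
# Minimal usage sketch for the helpers above (illustrative only; assumes a
# CUDA device, and this TypeConfig is just one plausible choice of types):
#
#   types = TypeConfig(act_type=torch.float16,
#                      weight_type=scalar_types.uint4b8,
#                      output_type=None,
#                      group_scale_type=torch.float16,
#                      group_zero_type=None,
#                      channel_scale_type=None,
#                      token_scale_type=None)
#   timers = bench(types, 128, 16, 4096, 4096,
#                  "fp16-gemm", "MKN=(16x4096x4096)", sweep_schedules=False)
#   print_timers(timers)  # defined below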
# runner
def print_timers(timers: list[TMeasurement]):
    compare = TBenchmark.Compare(timers)
    compare.print()


def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
    types = TypeConfig(
        act_type=args.act_type,
        weight_type=scalar_types.uint4b8 if args.group_zero_type is None \
            else scalar_types.uint4,
        output_type=args.out_type,
        group_scale_type=args.group_scale_type,
        group_zero_type=args.group_zero_type,
        channel_scale_type=args.channel_scale_type,
        token_scale_type=args.token_scale_type,
    )

    results: list[TMeasurement] = []
    for m, k, n in MKNs:
        timers = bench(types,
                       args.group_size,
                       m,
                       k,
                       n,
                       f"{args.act_type}-gemm",
                       f"MKN=({m}x{k}x{n})",
                       sweep_schedules=args.sweep_schedules)
        print_timers(timers)
        results.extend(timers)

    return results


# output makers
def make_output(
    data: list[TMeasurement],
    MKNs: Iterable[tuple[int, int, int]],
    base_description: str,
    timestamp=None,
):
    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    timestamp = int(time.time()) if timestamp is None else timestamp
    with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
        pkl.dump(data, f)


# argparse runners
def run_square_bench(args):
    dim_sizes = list(
        range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
    data = run(args, MKNs)

    make_output(data, MKNs, f"square_bench-{args.act_type}")


def run_range_bench(args):
    m_start, k_start, n_start = (int(x) for x in args.dim_start.split(","))
    m_end, k_end, n_end = (int(x) for x in args.dim_end.split(","))
    m_increment, k_increment, n_increment = \
        (int(x) for x in args.dim_increment.split(","))
    Ms = list(range(m_start, m_end + 1, m_increment))
    Ks = list(range(k_start, k_end + 1, k_increment))
    Ns = list(range(n_start, n_end + 1, n_increment))
    MKNs = list(product(Ms, Ks, Ns))
    data = run(args, MKNs)

    make_output(data, MKNs, f"range_bench-{args.act_type}")


def run_model_bench(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
        KNs = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KNs.append(KN)
        return KNs

    model_bench_data = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        Ms = args.batch_sizes
        KNs = model_shapes(model, tp_size)
        MKNs = []
        for m in Ms:
            for k, n in KNs:
                MKNs.append((m, k, n))

        data = run(args, MKNs)
        model_bench_data.append(data)

    type_string = f"{args.act_type}"

    # Print all results
    for data, model_tp in zip(model_bench_data, models_tps):
        model, tp_size = model_tp
        print(f"== Results {type_string} {model}-TP{tp_size} ====")
        print_timers(data)

    timestr = time.strftime("%Y%m%d-%H%M%S")

    all_results = []
    for d in model_bench_data:
        all_results.extend(d)

    # pickle all data
    with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
        args_dict = vars(args)
        args_dict.pop("func")
        pkl.dump({
            "args": args_dict,
            "results": all_results,
        }, f)


if __name__ == "__main__":

    def to_torch_dtype(dt):
        return {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "int8": torch.int8,
            "float8_e4m3fn": torch.float8_e4m3fn,
            "int": torch.int,
            "float": torch.float,
        }[dt]

    class ToTorchDtype(argparse.Action):

        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, to_torch_dtype(values))

    parser = FlexibleArgumentParser(
        description="""
Benchmark Machete GEMM.
    To run square GEMMs:
        python3 ./benchmarks/kernels/benchmark_machete.py --act-type float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64

    To run constant N and K and sweep M:
        python3 ./benchmarks/kernels/benchmark_machete.py --act-type float16 range_bench --dim-start 128,16384,16384 --dim-end 512,16384,16384 --dim-increment 64,1,1

    To run dimensions from a model:
        python3 ./benchmarks/kernels/benchmark_machete.py --act-type float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1

    Output:
        - a .pkl file, that is a list of raw torch.utils.benchmark.Measurement
          objects for the pytorch, cutlass, marlin, and machete implementations
          of the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )

    parser.add_argument(
        "--act-type",
        action=ToTorchDtype,
        required=True,
        choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'],
    )
    parser.add_argument(
        "--group-scale-type",
        action=ToTorchDtype,
        choices=['bfloat16', 'float16'],
    )
    parser.add_argument(
        "--group-zero-type",
        action=ToTorchDtype,
        choices=['bfloat16', 'float16'],
    )
    parser.add_argument(
        "--channel-scale-type",
        action=ToTorchDtype,
        choices=['float'],
    )
    parser.add_argument(
        "--token-scale-type",
        action=ToTorchDtype,
        choices=['float'],
    )
    parser.add_argument(
        "--out-type",
        action=ToTorchDtype,
        choices=['bfloat16', 'float16'],
    )
    parser.add_argument(
        "--group-size",
        type=int,
        help="Available options are ['None', '-1', '128'], default=128",
        default=128,
    )
    parser.add_argument(
        "--sweep-schedules",
        action="store_true",
        help="Run a sweep over all supported schedules",
    )
    parser.add_argument("--sweep-csv-out",
                        help="CSV to store sweep results",
                        default="sch_sweep_results.csv")

    subparsers = parser.add_subparsers(dest="cmd", required=True)

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument(
        "--dim-start",
        type=str,
        required=True,
        help="Start value for M,K,N as comma separated list")
    range_parser.add_argument(
        "--dim-end",
        type=str,
        required=True,
        help="End value (inclusive) for M,K,N as comma separated list")
    range_parser.add_argument(
        "--dim-increment",
        type=str,
        required=True,
        help="Increment value for M,K,N as comma separated list")
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--batch-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_BATCH_SIZES)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()

    _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
    args.func(args)

    if _SWEEP_SCHEDULES_RESULTS is not None:
        _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)
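
    # Illustrative: the pickled results written above can be reloaded and
    # re-compared offline (the file name pattern is an assumption based on
    # run_model_bench; <timestamp> is whatever time.strftime produced):
    #
    #   import pickle
    #   import torch.utils.benchmark as TBenchmark
    #   with open("model_bench-torch.float16-<timestamp>.pkl", "rb") as f:
    #       saved = pickle.load(f)
    #   TBenchmark.Compare(saved["results"]).print()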