From 5288c06aa03b100eab4f873452b65da941a1a232 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 20 Aug 2024 09:09:33 -0400 Subject: [PATCH] [Kernel] (1/N) Machete - Hopper Optimized Mixed Precision Linear Kernel (#7174) --- .gitignore | 3 + CMakeLists.txt | 40 + benchmarks/kernels/benchmark_machete.py | 372 +++++ benchmarks/kernels/graph_machete_bench.py | 64 + benchmarks/kernels/weight_shapes.py | 43 + csrc/cuda_utils.h | 10 + csrc/cutlass_extensions/cute_utils.cuh | 68 + csrc/cutlass_extensions/torch_utils.hpp | 154 ++ .../vllm_collective_builder.cuh | 43 + csrc/cutlass_extensions/vllm_custom_types.cuh | 50 + .../vllm_cutlass_library_extension.py | 49 + .../vllm_numeric_conversion.cuh | 795 +++++++++ csrc/ops.h | 19 + csrc/quantization/machete/Readme.md | 45 + csrc/quantization/machete/generate.py | 446 +++++ .../machete/machete_collective_builder.cuh | 33 + .../machete/machete_interleaving_utils.cuh | 35 + .../quantization/machete/machete_mainloop.cuh | 1473 +++++++++++++++++ .../machete/machete_mm_kernel.cuh | 237 +++ .../machete/machete_mm_launcher.cuh | 95 ++ .../machete/machete_prepack_kernel.cuh | 62 + .../machete/machete_prepack_launcher.cuh | 71 + .../machete/machete_prepacked_layout.cuh | 220 +++ csrc/quantization/machete/machete_pytorch.cu | 79 + csrc/torch_bindings.cpp | 15 + tests/kernels/test_machete_gemm.py | 272 +++ vllm/_custom_ops.py | 26 + .../layers/quantization/utils/quant_utils.py | 11 +- 28 files changed, 4828 insertions(+), 2 deletions(-) create mode 100644 benchmarks/kernels/benchmark_machete.py create mode 100644 benchmarks/kernels/graph_machete_bench.py create mode 100644 benchmarks/kernels/weight_shapes.py create mode 100644 csrc/cutlass_extensions/cute_utils.cuh create mode 100644 csrc/cutlass_extensions/torch_utils.hpp create mode 100644 csrc/cutlass_extensions/vllm_collective_builder.cuh create mode 100644 csrc/cutlass_extensions/vllm_custom_types.cuh create mode 100644 csrc/cutlass_extensions/vllm_cutlass_library_extension.py create mode 100644 csrc/cutlass_extensions/vllm_numeric_conversion.cuh create mode 100644 csrc/quantization/machete/Readme.md create mode 100644 csrc/quantization/machete/generate.py create mode 100644 csrc/quantization/machete/machete_collective_builder.cuh create mode 100644 csrc/quantization/machete/machete_interleaving_utils.cuh create mode 100644 csrc/quantization/machete/machete_mainloop.cuh create mode 100644 csrc/quantization/machete/machete_mm_kernel.cuh create mode 100644 csrc/quantization/machete/machete_mm_launcher.cuh create mode 100644 csrc/quantization/machete/machete_prepack_kernel.cuh create mode 100644 csrc/quantization/machete/machete_prepack_launcher.cuh create mode 100644 csrc/quantization/machete/machete_prepacked_layout.cuh create mode 100644 csrc/quantization/machete/machete_pytorch.cu create mode 100644 tests/kernels/test_machete_gemm.py diff --git a/.gitignore b/.gitignore index 2dfbf64d..761b00ac 100644 --- a/.gitignore +++ b/.gitignore @@ -87,6 +87,9 @@ target/ profile_default/ ipython_config.py +# generated files +**/generated/** + # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: diff --git a/CMakeLists.txt b/CMakeLists.txt index d47f1bb3..c8d4aaed 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -227,6 +227,46 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "-gencode arch=compute_90a,code=sm_90a") endif() + # + # For the Machete kernels we automatically generate sources for various + # preselected input type pairs and schedules. + # Generate sources: + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:$PYTHONPATH + ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py + RESULT_VARIABLE machete_generation_result + OUTPUT_VARIABLE machete_generation_output + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ) + + if (NOT machete_generation_result EQUAL 0) + message(FATAL_ERROR "Machete generation failed." + " Result: \"${machete_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log") + else() + message(STATUS "Machete generation completed successfully.") + endif() + + # Add machete generated sources + file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu") + list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES}) + message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}") + + # See comment above for scaled_mm_c3x (same if condition) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) + set_source_files_properties( + ${MACHETE_GEN_SOURCES} + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") + endif() + + # Add pytorch binding + list(APPEND VLLM_EXT_SRC + csrc/quantization/machete/machete_pytorch.cu) endif() define_gpu_extension_target( diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py new file mode 100644 index 00000000..ca45cba6 --- /dev/null +++ b/benchmarks/kernels/benchmark_machete.py @@ -0,0 +1,372 @@ +import argparse +import copy +import itertools +import math +import pickle as pkl +import time +from typing import Callable, Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + MarlinWorkspace) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + gptq_pack, pack_rows, quantize_weights) +from vllm.scalar_type import ScalarType, scalar_types +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024] +DEFAULT_TP_SIZES = [1] + + +def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor: + w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # make col major + return ops.machete_prepack_B(w_q, wtype) + + +def make_bench_tensors( + atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int, + k: int +) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor, + torch.tensor]]]: + assert wtype.is_integer(), "TODO: support floating point weights" + + # we want to make sure that weights don't fit into L2 cache between runs so + # we construct enough weights to exceed L2 cache, which is 50mb on a H100 + # so we target total weight size > 2*50mb + num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits)) + + a = torch.randn((m, k), device="cuda", dtype=atype) * 5 + weights = [ + torch.randn((k, n), device="cuda", dtype=atype) + for _ in range(num_weights) + ] + quanitized_weights = [ + quantize_weights(w, wtype, group_size) for w in weights + ] + + return a, quanitized_weights + + +# impl + + +# bench +def bench_fn(label: str, sub_label: str, description: str, + fn: Callable) -> TMeasurement: + + min_run_time = 1 + return TBenchmark.Timer( + stmt="fn()", + globals={ + "fn": fn + }, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def loop_over_weights( + a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor, + torch.tensor, torch.tensor]], + fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor], + None]): + for w_ref, w_q, w_s, _ in weights: + fn(a, w_ref, w_q, w_s) + + +def bench(atype: torch.dtype, + wtype: ScalarType, + group_size: int, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + benchmark_marlinv1: bool = True, + sweep_schedules: bool = True) -> Iterable[TMeasurement]: + a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k) + sub_label += f", L={len(weights)}" + + weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp) + for w_ref, w_q, w_s, w_zp in weights] + + timers = [] + # pytorch impl + timers.append( + bench_fn( + label, sub_label, "torch.matmul", lambda: loop_over_weights( + a, + weights, + lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref), + ))) + + if benchmark_marlinv1: + w_ref = weights[0][0] + + w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device) + sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device) + g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device) + + def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor: + w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape) + return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape, + wtype.size_bits) + + def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: + return marlin_permute_scales(w_s, *w_ref.shape, group_size) + + weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q), + marlinv1_permute_scales(w_s), w_zp) + for w_ref, w_q, w_s, w_zp in weights] + + workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + # marlinv1 + timers.append( + bench_fn( + label, sub_label, "marlin_orig", lambda: loop_over_weights( + a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops. + gptq_marlin_gemm(a, + w_q, + w_s, + w_zp_empty, + g_idx, + sort_indices, + workspace.scratch, + wtype, + size_m=a.shape[0], + size_n=w_ref.shape[1], + size_k=w_ref.shape[0], + is_k_full=True)))) + + # machete + timers.append( + bench_fn( + label, sub_label, "machete_heuristic", lambda: loop_over_weights( + a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm( + a, w_q, wtype, b_scales=w_s, b_group_size=group_size)))) + + if sweep_schedules: + print("Finding best schedule for machete") + best = None + best_schedule = None + schedules = ops.machete_supported_schedules(wtype) + for schedule in reversed(schedules): + + def run(a, _, w_q, w_s, schedule=schedule): + ops.machete_gemm(a, + w_q, + wtype, + w_s, + b_group_size=group_size, + schedule=schedule) + + res = bench_fn(label, sub_label, "machete_best", + lambda: loop_over_weights(a, weights_machete, run)) + + print(f" {res.median:5.5} ", schedule) + if not best or res.median < best.median: + best = res + best_schedule = schedule + print("Best schedule:", best_schedule) + timers.append(best) + + return timers + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, sweep_schedules: bool, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + + results = [] + for m, k, n in MKNs: + timers = bench(dtype, + scalar_types.uint4b8, + 128, + m, + k, + n, + f"{dtype}-gemm", + f"MKN=({m}x{k}x{n})", + sweep_schedules=sweep_schedules) + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output( + data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None, +): + + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, args.sweep_schedules, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, args.sweep_schedules, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, args.sweep_schedules, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == "__main__": + + def to_torch_dtype(dt): + if dt == "bfloat16": + return torch.bfloat16 + if dt == "float16": + return torch.float16 + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Machete GEMM. + + To run square GEMMs: + python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['bfloat16', 'float16']", + ) + parser.add_argument( + "--sweep-schedules", + action="store_true", + help="Run a sweep over all supported schedules", + ) + subparsers = parser.add_subparsers(dest="cmd", required=True) + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py new file mode 100644 index 00000000..1d076ed6 --- /dev/null +++ b/benchmarks/kernels/graph_machete_bench.py @@ -0,0 +1,64 @@ +import math +import pickle +import re +from collections import defaultdict +from typing import List + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from torch.utils.benchmark import Measurement as TMeasurement + +from vllm.utils import FlexibleArgumentParser + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Benchmark the latency of processing a single batch of ' + 'requests till completion.') + parser.add_argument('filename', type=str) + + args = parser.parse_args() + + with open(args.filename, 'rb') as f: + data: List[TMeasurement] = pickle.load(f) + + results = defaultdict(lambda: list()) + for v in data: + result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label) + if result is not None: + KN = result.group(1) + else: + raise Exception("MKN not found") + result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label) + if result is not None: + M = result.group(1) + else: + raise Exception("MKN not found") + + kernel = v.task_spec.description + results[KN].append({ + "kernel": kernel, + "batch_size": M, + "median": v.median + }) + + rows = int(math.ceil(len(results) / 2)) + fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows)) + axs = axs.flatten() + axs_idx = 0 + for shape, data in results.items(): + plt.sca(axs[axs_idx]) + df = pd.DataFrame(data) + sns.lineplot(data=df, + x="batch_size", + y="median", + hue="kernel", + style="kernel", + markers=True, + dashes=False, + palette="Dark2") + plt.title(f"Shape: {shape}") + plt.ylabel("time (median, s)") + axs_idx += 1 + plt.tight_layout() + plt.savefig("graph_machete_bench.pdf") diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py new file mode 100644 index 00000000..25ec9d60 --- /dev/null +++ b/benchmarks/kernels/weight_shapes.py @@ -0,0 +1,43 @@ +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "mistralai/Mistral-7B-v0.1": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-7b-hf": [ + ([4096, 12288], 1), + ([4096, 4096], 0), + ([4096, 22016], 1), + ([11008, 4096], 0), + ], + "meta-llama/Llama-3-8b": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-13b-hf": [ + ([5120, 15360], 1), + ([5120, 5120], 0), + ([5120, 27648], 1), + ([13824, 5120], 0), + ], + "meta-llama/Llama-2-70b-hf": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], +} diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 73944f4c..c3522421 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -1,5 +1,15 @@ #pragma once +#if defined(__CUDACC__) || defined(_NVHPC_CUDA) + #define HOST_DEVICE_INLINE __forceinline__ __host__ __device__ + #define DEVICE_INLINE __forceinline__ __device__ + #define HOST_INLINE __forceinline__ __host__ +#else + #define HOST_DEVICE_INLINE inline + #define DEVICE_INLINE inline + #define HOST_INLINE inline +#endif + int64_t get_device_attribute(int64_t attribute, int64_t device_id); int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); diff --git a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh new file mode 100644 index 00000000..1842fab8 --- /dev/null +++ b/csrc/cutlass_extensions/cute_utils.cuh @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +namespace cute { + +//////////////////////////////////////////////////////////////////// +// layout utils +//////////////////////////////////////////////////////////////////// + +// Permute layout based on indices, example: +// permute_layout<1, 0>(layout) will swap the two dimensions +// permute_layout<0, 2, 1>(layout) will swap the last two dimensions +template +CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) { + static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch"); + return cute::make_layout(cute::get(l)...); +} + +// is the layout f(x) = x +template +CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { + if constexpr (std::is_same_v) + return true; + else { + constexpr auto coalesced_layout = coalesce(Layout{}); + if constexpr (rank(coalesced_layout) == 1 && + stride<0>(coalesced_layout) == 1) { + return true; + } + return false; + } +} + +//////////////////////////////////////////////////////////////////// +// Pointer utils +//////////////////////////////////////////////////////////////////// + +template +static constexpr auto get_logical_ptr(PointerType* ptr) { + if constexpr (cute::sizeof_bits_v < 8) { + return cute::subbyte_iterator(ptr); + } else { + return ptr; + } +} + +//////////////////////////////////////////////////////////////////// +// Misc utils +//////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() { + constexpr auto bits = sizeof_bits_v * Elements{}; + if constexpr (bits % 128 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<128>{}; + } else if constexpr (bits % 64 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<64>{}; + } else if constexpr (bits % 32 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<32>{}; + } else if constexpr (bits % 16 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<16>{}; + } else { + return AutoVectorizingCopyWithAssumedAlignment<8>{}; + } +} + +}; // namespace cute diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp new file mode 100644 index 00000000..1618a340 --- /dev/null +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -0,0 +1,154 @@ +#pragma once + +#include + +#include "cute/layout.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" + +using ColumnMajor = typename cutlass::layout::ColumnMajor; +using RowMajor = typename cutlass::layout::RowMajor; + +namespace cute { + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g, + seq) { + return g(f(cute::get(static_cast(t)), I)...); +} + +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq) { + return make_shape(f(I)...); +} + +}; // namespace detail + +template +CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) { + if constexpr (cute::is_tuple::value) { + return detail::tapply_with_idx( + t, f, [](auto const&... a) { return cute::make_tuple(a...); }, + tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// calls: make_shape(f(0), f(1), ..., f(N-1)) +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) { + return detail::make_shape_from_idx(f, make_seq{}); +} + +}; // namespace cute + +// Make a layout from a tensor with `rank(Stride{})`, where the shape is the +// shape of the passed in tensor and the strides are of type `Stride` and +// contain the strides of the passed in tensor, checking that any static strides +// in `Stride{}` match the strides of the passed in tensor. +// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra +// strides are set to be 0 or 1. +template +static inline auto make_cute_layout(torch::Tensor const& tensor, + std::string_view name = "tensor") { + TORCH_CHECK(tensor.dim() <= rank(Stride{})); + auto stride = cute::transform_with_idx( + Stride{}, [&](auto const& stride_ele, auto const& idx) { + using StrideEle = std::decay_t; + + if (idx < tensor.dim()) { + if constexpr (cute::is_static_v) { + TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ", + name, ".stride(", idx, ") to be ", StrideEle::value); + return StrideEle{}; + } else { + return tensor.stride(idx); + } + } else { + // Extra strides are assumed to be 0 or 1 + if constexpr (cute::is_static_v) { + static_assert(StrideEle::value == 0 || StrideEle::value == 1); + } + return StrideEle{}; + } + }); + + auto shape = cute::make_shape_from_idx([&](auto const& idx) { + if (idx < tensor.dim()) + return tensor.size(idx); + else + return int64_t(1); + }); + + return make_layout(shape, stride); +} + +template +static inline auto maybe_make_cute_layout( + c10::optional const& tensor, + std::string_view name = "tensor") { + using Layout = decltype(make_cute_layout(*tensor)); + + if (tensor) { + return std::optional{make_cute_layout(*tensor, name)}; + } else { + return std::optional{}; + } +} + +// +// Torch Type to Cutlass Type (equivalent_cutlass_type) +// + +template +struct equivalent_cutlass_type { + using type = T; +}; + +template +using equivalent_cutlass_type_t = typename equivalent_cutlass_type::type; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::half_t; +}; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::bfloat16_t; +}; + +// +// equivalent_scalar_t (basically inverse of equivalent_cutlass_type) +// + +// Return a `c10::CppTypeToScalarType` compatible type, i.e. get the C++ from +// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half` +template +struct equivalent_scalar_type { + using type = T; +}; + +template +using equivalent_scalar_type_t = typename equivalent_scalar_type::type; + +template <> +struct equivalent_scalar_type { + using type = c10::Half; +}; + +template <> +struct equivalent_scalar_type { + using type = c10::BFloat16; +}; + +// get equivalent c10::ScalarType tag from compile time type +template +static inline constexpr c10::ScalarType equivalent_scalar_type_v = + c10::CppTypeToScalarType>::value; \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_collective_builder.cuh b/csrc/cutlass_extensions/vllm_collective_builder.cuh new file mode 100644 index 00000000..085ee129 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_collective_builder.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include "cutlass/gemm/collective/collective_builder.hpp" + +namespace cutlass::gemm::collective { +using namespace cute; + +// +// VLLMCollectiveBuilder is a wrapper around CollectiveBuilder that allows for +// for custom kernel tags, allowing you to build custom collectives. Without +// touching the cutlass library headers, using `CutlassKernelTag` will mean it +// will resort to using the standard cutlass collective builder. +// + +// Use the default Cutlass collective builder, i.e. use an unmodified cutless +// collective +struct CutlassKernelTag {}; + +template +struct VLLMCollectiveBuilder { + static_assert(sizeof(ElementA) == 0, + "Could not build a collective for given parameters."); +}; + +template +struct VLLMCollectiveBuilder< + CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, + ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType> { + using CollectiveOp = typename CollectiveBuilder< + ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB, + GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp; +}; + +}; // namespace cutlass::gemm::collective \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_custom_types.cuh b/csrc/cutlass_extensions/vllm_custom_types.cuh new file mode 100644 index 00000000..6146bdc1 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_custom_types.cuh @@ -0,0 +1,50 @@ +#pragma once + +#include "cutlass/integer_subbyte.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct vllm_biased_integer_subbyte : public integer_subbyte { + using Base = integer_subbyte; + + using Storage = typename Base::Storage; + using xint_t = typename Base::xint_t; + + using Base::bits_mask_; + using Base::sign_mask_; + using Base::storage; + + // + // Methods + // + + /// No operation + vllm_biased_integer_subbyte() = default; + + /// Conversion from integer type + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(int value) + : Base(value) {} + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(unsigned value) + : Base(value) {} + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(double value) + : Base(value) {} +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// "GPTQ" types, i.e. symmetric quantization +using vllm_uint4b8_t = vllm_biased_integer_subbyte<4, 8>; // u4b8 +using vllm_uint8b128_t = vllm_biased_integer_subbyte<8, 128>; // u8b128 + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct sizeof_bits> { + static constexpr int value = Bits; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py new file mode 100644 index 00000000..4fcfcd31 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -0,0 +1,49 @@ +import enum +from typing import Dict, Union + +from cutlass_library import * + +# +# Extend cutlass library with custom types, and missing values +# + + +class VLLMDataType(enum.Enum): + u4b8 = enum_auto() + u8b128 = enum_auto() + + +class MixedInputKernelScheduleType(enum.Enum): + TmaWarpSpecializedMixedInput = enum_auto() + TmaWarpSpecializedPingpongMixedInput = enum_auto() + TmaWarpSpecializedCooperativeMixedInput = enum_auto() + + +VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = { + **DataTypeNames, # type: ignore + **{ + VLLMDataType.u4b8: "u4b8", + VLLMDataType.u8b128: "u8b128", + } +} + +VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { + **DataTypeTag, # type: ignore + **{ + VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", + VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t", + } +} + +VLLMKernelScheduleTag: Dict[Union[ + MixedInputKernelScheduleType, KernelScheduleType], str] = { + **KernelScheduleTag, # type: ignore + **{ + MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput: + "cutlass::gemm::KernelTmaWarpSpecializedMixedInput", + MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput: + "cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput", + MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput: + "cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput", + } + } diff --git a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh new file mode 100644 index 00000000..2ad914f8 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh @@ -0,0 +1,795 @@ +#pragma once + +#include "cutlass/numeric_conversion.h" +#include "cutlass_extensions/vllm_custom_types.cuh" +#include "cutlass_extensions/cute_utils.cuh" + +// this file extends: +// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h +// with vllm specific type conversions, namely: vllm_uint4b8_t, vllm_uint8b128_t +// as well as adds interleaved numeric array converters for specific types. +// (interleaved numeric array converters can be more efficient for subbyte +// types) + +namespace cutlass { + +// InterleavedNumericArrayConverter is like NumericArrayConverter but also +// deinterleaves converted elements based on IlvBlkLayout, interleaving can +// make subbyte converts more efficient by allowing for efficient extraction +// of subbyte elements from a 32bit register. +template +struct InterleavedNumericArrayConverter { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + CUTE_INVALID_CONTROL_PATH( + "InterleavedNumericArrayConverter not implemented\n"); + return {}; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +template +struct InterleavedNumericArrayConverter< + IlvBlkLayout, T, S, N, Round, + std::enable_if_t()>> { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return Converter::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// TODO (LucasWilkinson): Implement +// for Array <= Array + +// .... + +template +struct ArrayConverterPacked32Bit { + using result_type = Array; + using source_type = Array; + + using result_packed_8_t = Array; + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_8_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + static_assert(N % 2 == 0, "N must be a multiple of 2"); + static_assert(cutlass::sizeof_bits_v >= 4); // TODO: add 16 packed sources + static_assert(32 % cutlass::sizeof_bits_v == 0); + static constexpr auto src_elems_per_32bit_reg = + 32 / cutlass::sizeof_bits_v; + + // Maybe not Valid. ScalarConverter will not actually work unless + // NumericConverter is implemented. However it won't be used + // anyways since we assert N % 2 == 0, just here for compliance with + // VectorizedConverter. + using ScalarConverter = NumericConverter; + + template + CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) { + if constexpr (sizeof(PackedSrc) == 1) { + return static_cast(reinterpret_cast(source)); + } else if constexpr (sizeof(PackedSrc) == 2) { + return static_cast(reinterpret_cast(source)); + } else { + static_assert(sizeof(PackedSrc) == 4); + return reinterpret_cast(source); + } + } + + // The core converter uses bit tricks to construct a known FP16 number, then + // does a subtraction in FP16 for the final result. + template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert(PackedSrcType::kElements == PackedResultType::kElements); + static_assert(PackedResultType::kElements == 2 || + PackedResultType::kElements == 4 || + PackedResultType::kElements == 8, + "Invalid PackedResultType must be 2, 4 or 8."); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + return RegConvert32bit::template convert(to_reg(source)); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + ArrayConverterPacked32Bit; + + if constexpr (src_elems_per_32bit_reg >= 8) { + detail::VectorizedConverter::convert< + ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t, + src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source); + } else if constexpr (src_elems_per_32bit_reg >= 4) { + detail::VectorizedConverter::convert(result, source); + } else { + detail::VectorizedConverter::convert(result, source); + } + + return result; + } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + // Below constructs the following temporary: + // fp16s_01 = {0x00, i4_01, 0x00, i4_01} + // fp16s_23 = {0x00, i4_23, 0x00, i4_23} + // fp16s_45 = {0x00, i4_45, 0x00, i4_45} + // fp16s_67 = {0x00, i4_67, 0x00, i4_67} + // We use inline asm instead of __byte_perm intrinsic since we don't want + // the documented (& 0x7) on the index. NVCC might be able to optimize it + // out since the index is a constexpr, but we choose to be safe about it + // here. + uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for F16 -> I4 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src), "n"(0), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a fp16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the FP16 to the correct value for the + // FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)}, + // where x1 in the high nibble and x0 is the low nibble then using hfma + // to subtract 1032 from that + // The AND does the following: + // 1) Clear the set bits for the int4 we will ignore. + // We use lop3 so that we can use 1 instruction for AND and XOR. + static constexpr uint32_t xor_mask = 0x64006400; + static constexpr uint32_t and_mask = 0xFFF0FF0F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 hfmas that do the following: + // {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032} + // = {x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032} + static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {72, 1032} + static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1} + + const half2& hfma_bias = reinterpret_cast(hfma_bias_rep); + const half2& hfma_scale = reinterpret_cast(hfma_scale_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, vllm_uint4b8_t, N, + Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032} + // For high nibble: + // {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16} + // - {72, 72} + static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024} + // For high nibble: + // {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64} + static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + // Hold output FP16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + + uint32_t const prmt_indices[2] = {0x5150, 0x5352}; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(r[ii]) + : "r"(src), "n"(start_byte_for_fp16), + "r"(prmt_indices[ii])); + } + + // -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes + static constexpr uint32_t bias_rep = 0x64806480; + const half2& bias = reinterpret_cast(bias_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hsub2(fp16x2_val, bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + PackedResultType r; + + // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of + // u8x4 source and stores the result in r (without introducing extra + // cvt.u32.u8 instruction) + uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653}; + uint32_t* result_as_int = reinterpret_cast(&r); + for (int ii = 0; ii < PackedResultType::kElements; ++ii) { + result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]); + // Subtract the magic number 0x4B000000 from tmp in floating-point + // arithmetic to obtain final result + r[ii] -= (8388608.f + 128.f); // fold in -128 bias + } + + return r; + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) { + // Hold output BF16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + uint32_t src_reg_shifted = src_reg >> 4; + + // Below constructs the following temporary: + uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for uint4b8_t -> BF16 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a BF16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the BF16 to the correct value for the + // BF16 magic_num. We will be constructing {128 + (x1+8), 128 + (x0+8)} + // and subtracting 136 to get {x1, x0} + static constexpr uint32_t xor_mask = 0x43004300; + static constexpr uint32_t and_mask = 0x000F000F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 bfmas that do the following: + // high BF16: + // hi_bf16 - 136, lo_bf16 - 136 + + // This is the BF16 {136, 136} represented as an integer. + static constexpr uint32_t bias_rep = 0x43084308; + const __nv_bfloat162& bias = + reinterpret_cast(bias_rep); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, bias); + } + + return reinterpret_cast(r); + } + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, vllm_uint4b8_t, N, + Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136} + static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128} + static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + // Not Valid, not supported, only here to satisfy the interface and to avoid + // a compile error. ScalarConverter will not actually work until + // NumericConverter is + // implemented + using ScalarConverter = + NumericConverter; + + template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert( + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private " + "convert dispatch."); + + NumericArrayConverter + convert_uint8_to_f32; + Array tmp = + convert_uint8_to_f32(source); + NumericArrayConverter + convert_f32_to_bf16_; + return convert_f32_to_bf16_(tmp); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/ops.h b/csrc/ops.h index 60945999..6bf0cff2 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -83,6 +83,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k); +namespace machete { + +std::vector supported_schedules( + vllm::ScalarTypeTorchPtr const& btype); + +torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, + vllm::ScalarTypeTorchPtr const& btype, + c10::optional const& scales, + c10::optional const& zeros, + c10::optional group_size, + c10::optional const& C, + c10::optional alpha, c10::optional beta, + c10::optional schedule); + +torch::Tensor prepack_B(torch::Tensor const& B, + vllm::ScalarTypeTorchPtr const& btype); + +}; // namespace machete + torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, diff --git a/csrc/quantization/machete/Readme.md b/csrc/quantization/machete/Readme.md new file mode 100644 index 00000000..9ddf8da9 --- /dev/null +++ b/csrc/quantization/machete/Readme.md @@ -0,0 +1,45 @@ +# Machete (Mixed Precision Cutlass-Based GEMM) + +Machete is a spiritual successor to the Marlin kernel but optimized for Hopper architectures and based on Cutlass. Being based on Cutlass, new type pairs and epilogues are easier to add compared to Marlin. + +## Overview + +Machete effectively performs + +``` +scale_type = w_s.dtype +compute_type = a.dtype +out = (w_q.to(scale_type) * w_s - w_z.to(scale_type)) @ a +``` + +Where `w_q` is a quantized weight matrix, `w_s` is the quantization scales, and +`w_z` is the quantization zeropoints. + +> **_NOTE:_** `w_z` is added after the scales so we can +use FMA operations, but this means they must have the scales pre-applied if the +supplied zeropoints assume that they will be subtracted before the scales are +applied. + +## API + +The main optimization within Machete is prepacking the weight matrix to more closely match the tensor core layouts, allowing for wider shared memory loads when loading the weight matrix. This means that the weight matrix must be prepacked before calling `machete_gemm`. The flow looks something like: + +``` +from vllm import _custom_ops as ops + +... +W_q_packed = ops.machete_prepack_B(w_q, wtype) +output = ops.machete_gemm( + a, + b_q=W_q_packed, + b_type=wtype, + b_scales=w_s, + b_group_size=group_size +) +``` + +## Code Generation + +Since Machete is based on Cutlass, we can generate multiple type pairs and different tile shapes using the same kernel template. We generate multiple instantiations of this template using `generate.py`. + +New type pairs (`TypeConfig`s) can be appended to `impl_configs` (in `generate()`), and these will get automatically generated (assuming they can be supported without issues). For each `TypeConfig`, you must also provide an `ImplConfig`, which bundles a `TypeConfig` with a list of `ScheduleConfig`s, `Specialization`s, and a default heuristic. The `ScheduleConfig`s (which contain info on tile shapes, tile scheduler, etc.) can perform differently for different problem shapes, and there is almost never one `ScheduleConfig` that works well for all problem shapes, so it is generally beneficial to generate different `ScheduleConfig`s for different potential problem shapes. This is where the heuristic comes in. For each `TypeConfig`, a default heuristic should be provided. This maps different problem shapes to different `ScheduleConfig`s and is used when the user does not provide the `schedule` parameter to `machete_gemm`. The `Specialization`s define what feature combinations to generate, i.e., `with_zeropoints`, `with_scales`, etc. We can reduce compile times and the final binary size by limiting the set of feature combinations we generate. \ No newline at end of file diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py new file mode 100644 index 00000000..09a98a5d --- /dev/null +++ b/csrc/quantization/machete/generate.py @@ -0,0 +1,446 @@ +import itertools +import math +import os +import shutil +from collections.abc import Iterable +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import jinja2 +# yapf conflicts with isort for this block +# yapf: disable +from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag, + EpilogueScheduleType, + MixedInputKernelScheduleType, + TileSchedulerTag, + TileSchedulerType, VLLMDataType, + VLLMDataTypeNames, VLLMDataTypeTag, + VLLMKernelScheduleTag) + +# yapf: enable + +# +# Generator templating +# + +DISPATCH_TEMPLATE = """ +#include "../machete_mm_launcher.cuh" + +namespace machete { +using GemmDispatcher_ = GemmDispatcher< + {{DataTypeTag[type_config.element_a]}}, // ElementA + {{DataTypeTag[type_config.element_b]}}, // ElementB + {{DataTypeTag[type_config.element_d]}}, // ElementD + {{DataTypeTag[type_config.accumulator]}}, // Accumulator + {{DataTypeTag[type_config.element_b_scale]}}, // Scales + {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints + +{% for s in schedules %}extern torch::Tensor +impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args); +{% endfor %} +template <> +torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) { + [[maybe_unused]] auto M = args.A.size(0); + [[maybe_unused]] auto N = args.B.size(1); + [[maybe_unused]] auto K = args.A.size(1); + + if (!args.schedule) { + {%- for cond, s in heuristic %} + {%if cond is not none%}if ({{cond}}) + {%- else %}else + {%- endif %} + return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %} + } + + {% for s in schedules %} + if (*args.schedule == "{{ gen_sch_name(s) }}") { + return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args); + } + {% endfor %} + TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for " + "schedule = ", *args.schedule); +} + +template <> +std::vector GemmDispatcher_::supported_schedules() { + return { + {% for s in schedules -%} + "{{ gen_sch_name(s) }}"{{ ", + " if not loop.last }}{%- endfor %} + }; +} + +}; // namespace machete +""" + +IMPL_TEMPLATE = """ +#include "../machete_mm_launcher.cuh" + +namespace machete { +template +using Kernel = MacheteKernelTemplate< + {{DataTypeTag[type_config.element_a]}}, // ElementA + {{DataTypeTag[type_config.element_b]}}, // ElementB + {{DataTypeTag[type_config.element_d]}}, // ElementD + {{DataTypeTag[type_config.accumulator]}}, // Accumulator + {{DataTypeTag[type_config.element_b_scale]}}, // Scales + {{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints + cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, + Config, with_C, with_scales, with_zeropoints>; + +{% for sch in schedules %} +{% set schedule_name = gen_sch_name(sch) -%} +struct sch_{{schedule_name}} { + using TileShapeNM = Shape<{{ + to_cute_constant(sch.tile_shape_mn)|join(', ')}}>; + using ClusterShape = Shape<{{ + to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>; + // TODO: Reimplement + // using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}}; + using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}}; + using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}}; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +}; + +torch::Tensor +impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) { + bool with_C = args.C.has_value(), with_scales = args.scales.has_value(), + with_zeropoints = args.zeros.has_value(); + + {% for s in specializations %} + if (with_C == {{s.with_C|lower}} + && with_zeropoints == {{s.with_zeropoints|lower}} + && with_scales == {{s.with_scales|lower}}) { + return run_impl>(args); + }{% endfor %} + + TORCH_CHECK_NOT_IMPLEMENTED( + false, "for the sake of compile times and binary size machete_mm(..) is " + " not implemented for with_C=", with_C, ", with_scales=", with_scales, + ", with_zeropoints=", with_zeropoints, + " (for {{type_name}}_sch_{{schedule_name}})"); +} +{% endfor %} + +}; // namespace machete +""" + +PREPACK_TEMPLATE = """ +#include "../machete_prepack_launcher.cuh" + +namespace machete { +using PrepackBDispatcher_ = PrepackBDispatcher< + {{DataTypeTag[type_config.element_a]}}, // ElementA + {{DataTypeTag[type_config.element_b]}}, // ElementB + {{DataTypeTag[type_config.element_d]}}, // ElementD + {{DataTypeTag[type_config.accumulator]}}, // Accumulator + {{DataTypeTag[type_config.element_b_scale]}}, // Scales + {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints + +using PrepackedLayoutB = PrepackedLayoutBTemplate< + {{DataTypeTag[type_config.element_a]}}, // ElementA + {{DataTypeTag[type_config.element_b]}}, // ElementB + {{DataTypeTag[type_config.element_d]}}, // ElementD + {{DataTypeTag[type_config.accumulator]}}, // Accumulator + cutlass::layout::ColumnMajor, + cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>; + +template <> +torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) { + return prepack_impl(B); +} +}; // namespace machete +""" + +TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput +TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative + + +@dataclass +class ScheduleConfig: + tile_shape_mn: Tuple[int, int] + cluster_shape_mnk: Tuple[int, int, int] + kernel_schedule: MixedInputKernelScheduleType + epilogue_schedule: EpilogueScheduleType + tile_scheduler: TileSchedulerType + + +@dataclass +class TypeConfig: + element_a: DataType + element_b: Union[DataType, VLLMDataType] + element_b_scale: DataType + element_b_zeropoint: DataType + element_d: DataType + accumulator: DataType + + +@dataclass +class Specialization: + with_C: bool + with_zeropoints: bool + with_scales: bool + + +@dataclass +class ImplConfig: + type_config: TypeConfig + schedule_configs: List[ScheduleConfig] + specializations: List[Specialization] + heuristic: List[Tuple[Optional[str], ScheduleConfig]] + + +def generate_schedule_name(schedule_config: ScheduleConfig) -> str: + tile_shape = ( + f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}" + ) + cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" + + f"x{schedule_config.cluster_shape_mnk[1]}" + + f"x{schedule_config.cluster_shape_mnk[2]}") + kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\ + .split("::")[-1] + epilogue_schedule = EpilogueScheduleTag[ + schedule_config.epilogue_schedule].split("::")[-1] + tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\ + .split("::")[-1] + + return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + + f"_{epilogue_schedule}_{tile_scheduler}") + + +# mostly unique shorter schedule_name +def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str: + kernel_terse_names_replace = { + "KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_", + "TmaWarpSpecializedCooperative_": "TmaCoop_", + "StreamKScheduler": "streamK", + } + + schedule_name = generate_schedule_name(schedule_config) + for orig, terse in kernel_terse_names_replace.items(): + schedule_name = schedule_name.replace(orig, terse) + return schedule_name + + +# unique type_name +def generate_type_signature(kernel_type_config: TypeConfig): + element_a = VLLMDataTypeNames[kernel_type_config.element_a] + element_b = VLLMDataTypeNames[kernel_type_config.element_b] + element_d = VLLMDataTypeNames[kernel_type_config.element_d] + accumulator = VLLMDataTypeNames[kernel_type_config.accumulator] + element_scale = VLLMDataTypeNames[kernel_type_config.element_b_scale] + element_zeropoint = VLLMDataTypeNames[ + kernel_type_config.element_b_zeropoint] + + return (f"{element_a}{element_b}{element_d}" + f"{accumulator}{element_scale}{element_zeropoint}") + + +# non-unique shorter type_name +def generate_terse_type_signature(kernel_type_config: TypeConfig): + element_a = VLLMDataTypeNames[kernel_type_config.element_a] + element_b = VLLMDataTypeNames[kernel_type_config.element_b] + + return f"{element_a}{element_b}" + + +def is_power_of_two(n): + return (n != 0) and (n & (n - 1) == 0) + + +def to_cute_constant(value: List[int]): + + def _to_cute_constant(value: int): + if is_power_of_two(value): + return f"_{value}" + else: + return f"Int<{value}>" + + if isinstance(value, Iterable): + return [_to_cute_constant(value) for value in value] + else: + return _to_cute_constant(value) + + +template_globals = { + "DataTypeTag": VLLMDataTypeTag, + "KernelScheduleTag": VLLMKernelScheduleTag, + "EpilogueScheduleTag": EpilogueScheduleTag, + "TileSchedulerTag": TileSchedulerTag, + "to_cute_constant": to_cute_constant, + "gen_sch_name": generate_terse_schedule_name, +} + + +def create_template(template_str): + template = jinja2.Template(template_str) + template.globals.update(template_globals) + return template + + +mm_dispatch_template = create_template(DISPATCH_TEMPLATE) +mm_impl_template = create_template(IMPL_TEMPLATE) +prepack_dispatch_template = create_template(PREPACK_TEMPLATE) + + +def create_sources(impl_config: ImplConfig, num_impl_files=2): + sources = [] + + type_name = generate_type_signature(impl_config.type_config) + terse_type_name = generate_terse_type_signature(impl_config.type_config) + + sources.append(( + f"machete_mm_{terse_type_name}", + mm_dispatch_template.render(type_name=type_name, + type_config=impl_config.type_config, + schedules=impl_config.schedule_configs, + heuristic=impl_config.heuristic), + )) + + sources.append(( + f"machete_prepack_{terse_type_name}", + prepack_dispatch_template.render( + type_name=type_name, + type_config=impl_config.type_config, + ), + )) + + num_schedules = len(impl_config.schedule_configs) + schedules_per_file = math.ceil(num_schedules / num_impl_files) + for part, i in enumerate(range(0, num_schedules, schedules_per_file)): + file_schedules = impl_config.schedule_configs[i:i + schedules_per_file] + + sources.append(( + f"machete_mm_{terse_type_name}_impl_part{part}", + mm_impl_template.render( + type_name=type_name, + type_config=impl_config.type_config, + schedules=file_schedules, + specializations=impl_config.specializations, + ), + )) + return sources + + +def generate(): + # See csrc/quantization/machete/Readme.md, the Codegeneration for more info + # about how this works + SCRIPT_DIR = os.path.dirname(__file__) + + schedules = [ + ScheduleConfig( + tile_shape_mn=tile_shape_mn, + cluster_shape_mnk=cluster_shape_mnk, + kernel_schedule=kernel_schedule, + epilogue_schedule=epilogue_schedule, + tile_scheduler=tile_scheduler, + ) for tile_shape_mn, cluster_shape_mnk in ( + ((128, 16), (1, 1, 1)), + ((128, 32), (1, 1, 1)), + ((128, 64), (1, 1, 1)), + ((128, 128), (1, 1, 1)), + ) for kernel_schedule in (TmaMI, ) for epilogue_schedule in (TmaCoop, ) + for tile_scheduler in (TileSchedulerType.StreamK, ) + ] + + # For now we use the same heuristic for all types + default_heuristic = [ + ("M > 64", + ScheduleConfig( + tile_shape_mn=(128, 128), + cluster_shape_mnk=(1, 1, 1), + kernel_schedule=TmaMI, + epilogue_schedule=TmaCoop, + tile_scheduler=TileSchedulerType.StreamK, + )), + ("M > 32", + ScheduleConfig( + tile_shape_mn=(128, 64), + cluster_shape_mnk=(1, 1, 1), + kernel_schedule=TmaMI, + epilogue_schedule=TmaCoop, + tile_scheduler=TileSchedulerType.StreamK, + )), + ("M > 16", + ScheduleConfig( + tile_shape_mn=(128, 32), + cluster_shape_mnk=(1, 1, 1), + kernel_schedule=TmaMI, + epilogue_schedule=TmaCoop, + tile_scheduler=TileSchedulerType.StreamK, + )), + (None, + ScheduleConfig(tile_shape_mn=(128, 16), + cluster_shape_mnk=(1, 1, 1), + kernel_schedule=TmaMI, + epilogue_schedule=TmaCoop, + tile_scheduler=TileSchedulerType.StreamK)) + ] + + impl_configs = [] + + GPTQ_kernel_type_configs = list( + (TypeConfig( + element_a=element_a, + element_b=element_b, + element_b_scale=element_a, + element_b_zeropoint=element_a, + element_d=element_a, + accumulator=DataType.f32, + ) for element_b in (VLLMDataType.u4b8, VLLMDataType.u8b128) + for element_a in (DataType.f16, DataType.bf16))) + + GPTQ_kernel_specializations = [ + Specialization(with_C=False, with_zeropoints=False, with_scales=True) + ] + + impl_configs += [ + ImplConfig(x[0], x[1], x[2], x[3]) + for x in zip(GPTQ_kernel_type_configs, itertools.repeat(schedules), + itertools.repeat(GPTQ_kernel_specializations), + itertools.repeat(default_heuristic)) + ] + + AWQ_kernel_type_configs = list( + (TypeConfig( + element_a=element_a, + element_b=element_b, + element_b_scale=element_a, + element_b_zeropoint=element_a, + element_d=element_a, + accumulator=DataType.f32, + ) for element_b in (DataType.u4, DataType.u8) + for element_a in (DataType.f16, DataType.bf16))) + + AWQ_kernel_specializations = [ + Specialization(with_C=False, with_zeropoints=True, with_scales=True) + ] + + impl_configs += [ + ImplConfig(x[0], x[1], x[2], x[3]) + for x in zip(AWQ_kernel_type_configs, itertools.repeat(schedules), + itertools.repeat(AWQ_kernel_specializations), + itertools.repeat(default_heuristic)) + ] + + output_dir = os.path.join(SCRIPT_DIR, "generated") + + # Delete the "generated" directory if it exists + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + # Create the "generated" directory + os.makedirs(output_dir) + + # Render each group of configurations into separate files + for impl_config in impl_configs: + for filename, code in create_sources(impl_config): + filepath = os.path.join(output_dir, f"{filename}.cu") + with open(filepath, "w") as output_file: + output_file.write(code) + print(f"Rendered template to {filepath}") + + +if __name__ == "__main__": + generate() diff --git a/csrc/quantization/machete/machete_collective_builder.cuh b/csrc/quantization/machete/machete_collective_builder.cuh new file mode 100644 index 00000000..a74cf8b2 --- /dev/null +++ b/csrc/quantization/machete/machete_collective_builder.cuh @@ -0,0 +1,33 @@ +#pragma once + +#include "cutlass_extensions/vllm_collective_builder.cuh" +#include "machete_mainloop.cuh" + +namespace cutlass::gemm::collective { +using namespace cute; + +struct MacheteKernelTag {}; + +template +struct VLLMCollectiveBuilder< + MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_, + GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB, + ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, + KernelScheduleType, + cute::enable_if_t<( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v)>> { + using CollectiveOp = machete::MacheteCollectiveMma< + ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, + AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, + StageCountType, KernelScheduleType>; +}; + +}; // namespace cutlass::gemm::collective \ No newline at end of file diff --git a/csrc/quantization/machete/machete_interleaving_utils.cuh b/csrc/quantization/machete/machete_interleaving_utils.cuh new file mode 100644 index 00000000..d397f87f --- /dev/null +++ b/csrc/quantization/machete/machete_interleaving_utils.cuh @@ -0,0 +1,35 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace machete { + +using namespace cute; + +// get an interleaved block layout where each element consecutive element has a +// stride of bit_stride and the block width is blk_bit_width, +// examples: +// size_bits = 8, bit_stride = 8, blk_bit_width = 32 -> 4:1 +// size_bits = 8, bit_stride = 16, blk_bit_width = 32 -> (2, 2):(2, 1) +// size_bits = 4, bit_stride = 8, blk_bit_width = 32 -> (4, 2):(2, 1) +// size_bits = 4, bit_stride = 16, blk_bit_width = 32 -> (2, 4):(4, 1) +template +CUTE_HOST_DEVICE static constexpr auto get_interleaved_blk_layout() { + static_assert(blk_bit_width % bit_stride == 0); + static_assert(bit_stride % cute::sizeof_bits_v == 0); + + constexpr auto elems_per_blk = blk_bit_width / cute::sizeof_bits_v; + + if constexpr (cute::sizeof_bits_v == bit_stride) { + // identity layout + return Layout>>{}; + } else { + constexpr auto elems_per_stride = bit_stride / cute::sizeof_bits_v; + constexpr auto num_strides = elems_per_blk / elems_per_stride; + return Layout, Int>, + Stride, Int<1>>>{}; + } +} + +}; // namespace machete diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh new file mode 100644 index 00000000..3d574ad9 --- /dev/null +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -0,0 +1,1473 @@ +// +// Based off of: +// cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +// Specifically: +// https://github.com/NVIDIA/cutlass/tree/06b21349bcf6ddf6a1686a47a137ad1446579db9/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +// Referred to as upstream from in the comments +// +// The main optimization machete implements compared to upstream is to prepack +// the weight matrix to more closely match the shape of the wgmma instructions +// allowing for wider (ideally 128bit) shared memory loads. For subbyte types +// this is done by packing values from multiple wgmma loads (for a single +// thread) into a single 128bit load. This is very similar to layout used in +// Marlin, although specific to the wgmma instructions. +// +// Since the wgmma instructions only support sourcing from registers for the A +// operand, and we want to upconvert/decompress the weight values/elements +// before feeding them into the tensor cores in registers, we need the weight +// matrix to be A. To achieve this we compute the transpose of Y = XW^t as +// Y^t = W^tX^t. This is mostly done outside of this file in +// csrc/quantization/machete/machete_mm_kernel.cuh, but this why A is the +// quantized/narrow type and has the prepacked layout despite the API being: +// B_prepacked = machete_prepack_B(B) +// Y = machete_mm(A, B_prepacked) +// +#pragma once + +// clang-format off +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/layout.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" +#include "cutlass/trace.h" + +#include "cutlass/detail/collective.hpp" +// clang-format on + +#include "cutlass_extensions/cute_utils.cuh" + +namespace machete { + +using namespace cute; +using namespace cutlass; +using namespace cutlass::gemm; +using namespace cutlass::gemm::collective; +using namespace cutlass::gemm::collective::detail; + +template +struct MacheteCollectiveMma { + using Schedule = KernelScheduleType; + static_assert( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "KernelSchedule must be one of the warp specialized policies"); + + public: + static constexpr bool ALayoutIsPrepacked = true; + + // Prepacked block shape (N is M in the transposed problem) + using PPBlockShape_MK = typename GmemLayoutA::PPBlockShape_NK; + // Prepacked blocks per dim for a single MMA tile + using PPBlocksPerTile_MK = decltype(make_shape( + size<0>(TileShape_MNK{}) / size<0>(PPBlockShape_MK{}), + size<2>(TileShape_MNK{}) / size<1>(PPBlockShape_MK{}))); + + using IlvdBlkLayout = typename GmemLayoutA::IlvdBlkLayout; + + static_assert(size<0>(TileShape_MNK{}) % size<0>(PPBlockShape_MK{}) == 0, + "M in PPBlockShape_MK must evenly divide M TileShape_MNK"); + static_assert(size<2>(TileShape_MNK{}) % size<1>(PPBlockShape_MK{}) == 0, + "K in PPBlockShape_MK must evenly divide K TileShape_MNK"); + + using ArchTag = arch::Sm90; + using TileShape = TileShape_MNK; + using ClusterShape = ClusterShape_MNK; + using ElementA = deduce_mixed_width_dtype_t<0, ElementATuple_>; + using StrideA = TagToStrideA_t; + using ElementB = ElementB_; + using StrideB = TagToStrideB_t; + using ElementAccumulator = ElementAccumulator_; + using ElementMma = ElementB; + using ElementATuple = + cute::conditional_t::value, + cute::tuple, ElementATuple_>; + + static constexpr cute::GMMA::Major GmmaMajorA = + gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = + gmma_rs_tag_to_major_B(); + + // For coop schedules we have two warp groups cooperatively issuing wgmma + // instructions so we use 2 atoms along the M dim (one for each warpgroup) + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutMNK{})); + + private: + // + // the setup section (until "section setup end") contains a combination of + // modified code from (used as a starting point): + // `cutlass/gemm/collective/builders/sm90_gmma_builder.inl` + // `cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp` + // (upstream) + // + // however in-order to simplify the code we combine a lot of the logic from + // `CollectiveMma` and `CollectiveBuilder` into this class, this also makes + // sense given that we have flexibility on layouts here. We also simplify the + // code by only supporting scales and zeros for A (in the transposed problem, + // B from an API perspective), also since we force A to be the narrow type + // (i.e. the type to be upconverted) we can remove all the `SwapAB` logic in + // the upstream also simplifying the code. This section includes new logic + // (compared ustream) for handling the prepacked-A layouts (in the transposed + // problem, B from an API perspective) + // + using ElementScale = deduce_mixed_width_dtype_t<1, ElementATuple_>; + using ElementZero = deduce_mixed_width_dtype_t<2, ElementATuple_>; + + static constexpr bool IsANarrow = cutlass::sizeof_bits::value < + cutlass::sizeof_bits::value; + static_assert(IsANarrow, + "A must be the narrow one since its the one that flows through " + "registers."); + + public: + static constexpr int PipelineStages = + compute_stage_count_or_override_single_affine_transformed_input< + sm90_smem_capacity_bytes, ElementA, ElementB, ElementScale, + ElementZero, TileShape_MNK>(StageCountType{}); + + struct DispatchPolicy { + constexpr static int Stages = PipelineStages; + using ClusterShape = ClusterShape_MNK; + using Schedule = KernelScheduleType; + }; + + using GmemTiledCopyA = + decltype(sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = + decltype(sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + // ((T, V), (BlocksM, BlocksK), pipe) -> offset + using SmemLayoutA = decltype(GmemLayoutA::TVbNbKL_to_offset( + make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), + Int{}))); + + using SmemLayoutAtomARowMajor = + decltype(rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + + using SmemLayoutAtomScale = Layout< + Shape(SmemLayoutAtomARowMajor{})), cute::Int<1>>>; + + using SmemLayoutAtomB = + decltype(rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + + using SmemCopyAtomA = Copy_Atom; + using SmemCopyAtomB = void; + + // + // Validity checks + // + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(is_aligned(), + "Should meet TMA alignment requirement\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, + "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + private: + enum class ConversionMode { + DirectConvert, + ConvertAndScale, + ConvertAndScaleWithZero + }; + + public: + // + // Type Aliases + // + using KernelSchedule = KernelScheduleType; + + // For cases where we can't have a void type, we can use this to allow the + // code to compile when the scale / zero is void. + using NonVoidElementScale = + cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = + cute::conditional_t, float, ElementZero>; + + // These are always MN major + using StrideScale = cute::Stride, int64_t, int64_t>; + // For cases where we can't have a void scale, we can use this to allow the + // code to compile when the scale is void. + using NonVoidStrideScale = + cute::conditional_t, + cute::Stride<_1, int64_t, int64_t>, StrideScale>; + + static_assert((cutlass::gemm::detail::is_k_major()), + "The transformed matrix (A) must be K-major."); + + static_assert((sizeof(ElementB) == 2) || + (cutlass::gemm::detail::is_k_major() && + cutlass::gemm::detail::is_k_major()), + "The unscaled element (matrix B) must be 2 bytes OR both " + "inputs must be K-major"); + + static_assert(cutlass::gemm::detail::is_mn_major(), + "Scale must be MN major [Col Major if A is scaled, Row Major " + "if B is scaled]."); + + static_assert(std::is_same_v, + "TiledMma::ValTypeC must be the same as ElementAccumulator."); + + using GmemTiledCopyScale = cute::SM90_TMA_LOAD; + + using SmemCopyAtomScale = Copy_Atom; + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any + // rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = + cute::conditional_t>>; + using InternalElementB = + cute::conditional_t>>; + + using TransformA = cute::identity; + using TransformB = cute::identity; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = + cute::conditional_t; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), + shape<1>(SmemLayoutAtomScale{}))); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, + "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomScale{}) == 2, + "SmemLayoutAtomScale must be rank 2"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, + "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, + "SmemLayoutAtomScale must evenly divide tile k shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutACopy = decltype(tile_to_shape( + SmemLayoutAtomARowMajor{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + + // It is assumed that the scales and zero-points share the same smem layout + using SmemLayoutScale = decltype(tile_to_shape( + SmemLayoutAtomScale{}, + make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), + Int{}))); + + // If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major + // only (e.g. tf32, fp32, fp8, int8). + static constexpr bool IsLayoutAmnBmn = + cute::is_same_v, + layout::ColumnMajor> && + cute::is_same_v, + layout::RowMajor>; + + static_assert(DispatchPolicy::Stages >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc " + "for this mainloop."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + using GmmaSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + + // These two restrictions are related, so we place the assertions together. + // To relax them, we need to handle loading more than 1 row of scales for + // every main loop iteration. We must also handle updating the pipeline + // transaction bytes on the fly. NOTE: Deleting this assertion without + // required changes will cause the code to hang. + static_assert(size<1>(SmemLayoutAtomScale{}) == 1, + "size<1>(SmemLayoutAtomScale) must be 1."); + + private: + static constexpr ConversionMode get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = + KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + + // Same as upstream, should be kept the same when possible + static constexpr auto elements_per_smem_scale() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } else if constexpr (ModeHasScales) { + return cute::cosize_v; + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in scale smem allocation."); + } + } + + // Same as upstream, should be kept the same when possible + static constexpr auto elements_per_smem_zero() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert || + KernelConversionMode == ConversionMode::ConvertAndScale) { + return 0; + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + return cute::cosize_v; + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in scale smem allocation."); + } + } + + // Same as upstream, should be kept the same when possible, not formatte for + // easier comparison + // clang-format off + // These methods use some the public members of the class. For that reason, we define them after the public section. + static constexpr uint32_t + compute_tma_transaction_bytes_mk() { + constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return baseline_bytes; + } + else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return baseline_bytes + scale_tx_bytes; + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Scale and zero share smem layout + constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + return baseline_bytes + scale_tx_bytes + zero_tx_bytes; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + + static constexpr uint32_t + compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); + } + // clang-format on + + // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) + using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset( + make_shape(int32_t(0), int32_t(0), int32_t(0))))); + + using ATensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + shape(GmemLayoutA::TVbNbKL_to_offset( + make_shape(int32_t(0), int32_t(0), int32_t(0)))), + PrepackedStrideA{})); + + using BTensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + repeat_like(StrideB{}, int32_t(0)), StrideB{})); + using ScaleTensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{})); + + using ZeroTensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{})); + + static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) { + return make_tma_copy( + GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}), + shape(SmemLayoutA{}(_, _, cute::Int<0>{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + } + + static constexpr auto make_tma_copy_scale( + ScaleTensor tensor_scale = ScaleTensor{}) { + return make_tma_copy(GmemTiledCopyScale{}, tensor_scale, + SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + } + + static constexpr auto make_tma_copy_zero( + ZeroTensor tensor_zero = ZeroTensor{}) { + return make_tma_copy(GmemTiledCopyScale{}, tensor_zero, + SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + } + + static constexpr auto make_tma_copy_B(BTensor tensor_b = BTensor{}) { + return make_tma_copy( + GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + } + + public: + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // with `RealInternalElementA` -> `ElementA` since we support `SwapAB` logic + // clang-format off + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + + // Just pick the max alignment of A and B since it is required to be at least 128B + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct SharedStorage + { + static constexpr int scale_elements = elements_per_smem_scale(); + static constexpr int zero_elements = elements_per_smem_zero(); + struct TensorStorage : cute::aligned_struct { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + ElementScale const* ptr_S = nullptr; + NonVoidStrideScale dS{}; + int group_size = 0; + ElementZero const* ptr_Z = nullptr; + uint32_t mma_promotion_interval = 4; + }; + // clang-format on + + // + // section setup end + // + + // Similar (but not idendtical) to upstream, should be kept the same when + // possible + // compared to upstream we use `make_tma_copy_A`, `make_tma_copy_B` etc. to + // define the TMA types + // Device side kernel params + struct Params { + public: + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A()); + using TMA_Scale = decltype(make_tma_copy_scale()); + using TMA_Zero = decltype(make_tma_copy_zero()); + using TMA_B = decltype(make_tma_copy_B()); + + // required by outer loop: i.e. + // cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Scale tma_load_scale; + TMA_Zero tma_load_zero; + int64_t scale_k; + int group_size; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + }; + + // + // Methods + // + + // Similar (but not idendtical) to upstream, should be kept the same when + // possible + // compared to upstream we use `make_tma_copy_A` and `TVbNbKL_to_offset` here + // to handle the prepacked layout + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, + void* workspace) { + (void)workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is + // only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + auto make_logical_tensor = [&](auto ptr, auto shape, auto stride) { + return make_tensor(get_logical_ptr(ptr), make_layout(shape, stride)); + }; + + typename Params::TMA_A tma_load_a; + typename Params::TMA_B tma_load_b; + typename Params::TMA_Scale tma_load_scale; + typename Params::TMA_Zero tma_load_zero; + + auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); + tma_load_a = make_tma_copy_A( + make_logical_tensor(ptr_A, shape(layout), stride(layout))); + + tma_load_b = make_tma_copy_B( + make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB)); + + if constexpr (ModeHasScales) { + tma_load_scale = make_tma_copy_scale(make_logical_tensor( + args.ptr_S, make_shape(M, args.group_size, L), args.dS)); + } + + if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + tma_load_zero = make_tma_copy_zero(make_logical_tensor( + args.ptr_Z, make_shape(M, args.group_size, L), args.dS)); + } + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0}; + } else if constexpr (ModeHasScales) { + auto scale_k = (K + args.group_size - 1) / args.group_size; + + return {tma_load_a, tma_load_b, tma_load_scale, + tma_load_zero, scale_k, args.group_size}; + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in to_underlying_arguments."); + } + } + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // with `SwapAB ? N : M -> M` since we dont support SwapAB + // clang-format off + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + implementable = implementable && (args.ptr_S == nullptr); + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (ModeHasScales) { + const int scale_mn = M; + const int scale_k = (K + args.group_size - 1) / args.group_size; + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); + implementable = implementable && args.group_size != 0; + implementable = implementable && (args.ptr_S != nullptr); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.ptr_Z != nullptr); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor()); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); + } + + } + // clang-format off + + // Modified from upstream, should be kept close to that when possible + // the main difference is special handling for the prepacked A layout + // + // Set up the data needed by this collective for load and mma. + // Returns a tuple of tensors. The collective and the kernel layer have the + // contract Returned tuple must contain at least two elements, with the first + // two elements being: gA_mkl - The tma tensor, A after a local tile so it + // has shape (TILE_V,TILE_B,m,k,l) gB_nkl - The tma tensor, B after a local + // tile so it has shape (TILE_N,TILE_K,n,k,l) The rest of the tensors can be + // specified as needed by this collective. + // NOTE: TILE_B is the prepacked block index within a tile. TILE_V is the + // values within a prepacked block. + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, + Params const& mainloop_params) const { + using X = Underscore; + auto M = get<0>(problem_shape_MNKL), N = get<1>(problem_shape_MNKL), + K = get<2>(problem_shape_MNKL), L = get<3>(problem_shape_MNKL); + + // (TILE_V,TILE_B,m,k,l) + auto make_gA_mkl = [&]() { + // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) + auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout)); + return local_tile(mA_mkl, + make_shape(size<0>(layout), PPBlocksPerTile_MK{}), + make_coord(0, make_coord(_, _))); + }; + + // (TILE_N,TILE_K,n,k,l) + auto make_gB_nkl = [&]() { + Tensor mB_nkl = + mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); + return local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), + Step{}); + }; + + // (TILE_M,TILE_Scale_K,m,scale_k,l) + auto make_gS_mkl = [&]() { + auto scale_k = mainloop_params.scale_k; + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor( + make_shape(M, scale_k, L)); + return local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); + }; + + // (TILE_M,TILE_Scale_K,m,scale_k,l) + auto make_gZ_mkl = [&]() { + auto scale_k = mainloop_params.scale_k; + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor( + make_shape(M, scale_k, L)); + return local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_, _)); + }; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(make_gA_mkl(), make_gB_nkl()); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScale) { + return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl()); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl(), + make_gZ_mkl()); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in load_init."); + } + } + + // Similar to upstream, should be kept close to that when possible + // the main difference is in the layout comments + // clang-format off + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + /// This overload gets triggered when we have scales. + template < + class... Ts, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs"); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); + } + + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A, B and Scales + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (TILE_V,TILE_B,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (TILE_N,TILE_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } + else if constexpr (ModeHasScales) { + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes + // on the fly. + // We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K + // is a multiple of the threadblock tile K + const int ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + const int scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K. + copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + // clang-format off + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + // clang-format on + + // Modified from upstream, should be kept close to that when possible + // the main differences are handling the prepacked A layout, and separating + // the loading of A from upcoverting A + // + // Perform a collective-scoped matrix multiply-accumulate + // Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, FrgTensorC& accum, + int k_tile_count, int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, + "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutB{}) == 3, + "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, + "SmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for " + "RF sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for " + "smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + // ((T, (FrgV,(RestM, RestK)), (BlocksM, BlocksK), pipe) -> offset + auto constexpr smem_A = SmemLayoutA{}; + + // convert: + // ((T, (MMA,(MMA_M, MMA_K)), (BlocksM, BlocksK), pipe) -> offset + // to: + // (T, MMA, ((MMA_M, BlocksM), (MMA_K, BlocksK)), pipe) -> offset + // which can be thought of as: + // (T, MMA, (MMA_M, MMA_K), pipe) -> offset + auto constexpr smem_A_mma_ = + make_layout(get<0, 0>(smem_A), get<0, 1, 0>(smem_A), + zip(get<0, 1, 1>(smem_A), get<1>(smem_A)), get<2>(smem_A)); + // flatten to: + // (T, MMA, MMA_M, MMA_K, pipe) -> offset + auto constexpr smem_A_mma = smem_A_mma_(_, _, make_coord(_, _), _); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), + smem_A_mma); // (T, MMA, MMA_M, MMA_K, pipe) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = sA(thread_idx, _, _, _, _); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate fragments and descriptors + Tensor tCrA_load = make_tensor( + tCsA(_, _, _, Int<0>{}).shape()); // (MMA,MMA_N,MMA_K) + Tensor tCrA_mma = make_fragment_like(tCrA_load); + + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + static constexpr int A_CPY_VEC = + decltype(max_common_vector(tCsA, tCrA_load)){}; + + static constexpr int COVERSION_WIDTH = + std::min(A_CPY_VEC, int(size<0>(tCrA_mma))); + + auto load_A_to_registers = [&](int read_stage) { + copy(create_auto_vectorizing_copy(), + tCsA(_, _, _, read_stage), tCrA_load(_, _, _)); + }; + + // Partition of thread -> shared and thread -> RF + auto partitioned_extra_info = + partition_extra_mma_info(thread_mma, shared_tensors); + auto copy_partitions_extra_info = retile_extra_mma_info( + tiled_mma, partitioned_extra_info, warp_group_thread_idx); + CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + auto convert_A = [&, a_vec = Int{}](int k_block, + int read_stage) { + load_extra_info_to_registers(partitioned_extra_info, + copy_partitions_extra_info, k_block, + read_stage); + transform_A_kblock(tCrA_load, a_vec, tCrA_mma, partitioned_extra_info, + k_block); + }; + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + + constexpr int K_BLOCK_MAX = size<2>(tCrA_load); + + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; + // first k tile + { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // copy smem->rmem for A operand + load_A_to_registers(read_stage); + convert_A(0, read_stage); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + if (k_block < K_BLOCK_MAX - 1) { + convert_A(k_block + 1, smem_pipe_read.index()); + } + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), + tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + } + + --k_tile_count; + if (k_tile_count > 0) { + // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to + // overwrite the A registers for the first mma. + warpgroup_wait(); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + load_A_to_registers(smem_pipe_read.index()); + convert_A(0, smem_pipe_read.index()); + } + } + + if (k_tile_count == 0) { + return; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), + tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) { + // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, + // so we can release prior barrier + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ + // on it + ++smem_pipe_release; + } + + if (k_block == 0) { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + load_A_to_registers(smem_pipe_read.index()); + convert_A(0, smem_pipe_read.index()); + } else { + convert_A(k_block + 1, read_stage); + } + } + warpgroup_fence_operand(accum); + } + + warpgroup_fence_operand(accum); + + { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), + tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) { + // release prior barrier + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ + // on it + ++smem_pipe_release; + } + + if (k_block < K_BLOCK_MAX - 1) { + convert_A(k_block + 1, read_stage); + } + } + } + + warpgroup_fence_operand(accum); + } + + // Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, + PipelineState smem_pipe_release, + int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = 1; + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on + // it + ++smem_pipe_release; + } + } + + private: + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Utilities for any additional inputs inside of the TMA load + template + CUTLASS_DEVICE + auto partition_extra_tma_inputs( + Params const& mainloop_params, + cute::tuple const& load_inputs, + TensorStorage& shared_tensors, + uint2 const& cluster_local_block_id, + int const m_coord, + int const l_coord) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gS_mkl = get<2>(load_inputs); + auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); + Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tSgS, tSsS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gZ_mkl = get<3>(load_inputs); + auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); + Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + // clang-format off + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Utilities for partitioning extra inputs for loading from smem in the mainloop. + template + CUTLASS_DEVICE + auto partition_extra_mma_info( + ThreadMma const& mma_thread_slice, + TensorStorage& shared_tensors) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).shape()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS, tCrS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = mma_thread_slice.partition_A(sZ); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).shape()); + return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + // clang-format on + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Returns the tiled copy and copy views for the extra inputs. + template + CUTLASS_DEVICE + auto retile_extra_mma_info( + TiledMma const& tiled_mma, + cute::tuple& partitioned_extra_info, + int const warp_group_thread_idx) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); + auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); + Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + // clang-format on + + // Similar to `copy_A_and_extra_info` upstream, should be kept the same when + // possible + // the main differences this only loads the extra info into registers and + // not A (since we now preload more of A in the main pipeline) + // Load scales and zeros into registers if required + template + CUTLASS_DEVICE void load_extra_info_to_registers( + cute::tuple const& partitioned_mma_extra_info, + cute::tuple const& tiled_copy_and_views, int k_block, + int read_stage) { + if (k_block == 0) { + // We are starting a new k-tile so copy the scale + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + } else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views); + auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); + auto tCsS = cute::get<0>(partitioned_mma_extra_info); + copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), + tCrS_copy_view(_, _, k_block)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + auto tCsZ = cute::get<2>(partitioned_mma_extra_info); + auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views); + copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), + tCrZ_copy_view(_, _, k_block)); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); + } + } + } + + // Similar to upstream, should be kept the same when possible. + // the main differences are that `convert_tensor` supports interleaved + // layouts and bfloat16 has been optimized. `transform_internal_A` has also + // been inlined for code simplicity. + // Utilities to transform A. + template + CUTLASS_DEVICE void transform_A_kblock( + TCrA_load const& tCrA_load, cute::Int vec_A, + TCrA_mma& tCrA_mma, cute::tuple const& partitioned_extra_info, + int const k_block) { + auto in = tCrA_load(_, _, k_block); + auto out = tCrA_mma(_, _, k_block); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + convert_tensor(in, out, vec_A); + } else if constexpr (ModeHasScales) { + auto tCrS = cute::get<1>(partitioned_extra_info); + auto converted_inputs = + make_fragment_like(tCrA_mma)(_, _, k_block); + auto scales = tCrS(_, _, 0); + + // First, we upcast the inputs to the scale type + convert_tensor(in, converted_inputs, vec_A); + // Apply scales and broadcast across inputs, store in converted_inputs + + // We need to cast to nv_bfloat16 for the multiply since + // `cutlass::bfloat16_t` has an overloaded operator* that upconverts to + // float, which nvcc will not optimize to using vectorized fma + // instructions (i.e. hfma.bf16_v2) + if constexpr (std::is_same_v) { + cute::transform( + recast(converted_inputs), recast(scales), + recast(converted_inputs), cute::multiplies{}); + } else { + cute::transform(converted_inputs, scales, converted_inputs, + cute::multiplies{}); + } + + // Apply zeros if required + if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + auto tCrZ = cute::get<3>(partitioned_extra_info); + auto converted_zeros = make_fragment_like(tCrZ)(_, _, 0); + + convert_tensor(tCrZ(_, _, 0), converted_zeros); + if constexpr (std::is_same_v) { + cute::transform(recast(converted_inputs), + recast(converted_zeros), + recast(converted_inputs), cute::plus{}); + } else { + cute::transform(converted_inputs, converted_zeros, converted_inputs, + cute::plus{}); + } + } + + // Finally, we convert the scaled inputs to the mma type. + convert_tensor(converted_inputs, out); + } else { + static_assert(cutlass::detail::dependent_false, + "No A data is loaded."); + } + } + + // Modified from upstream, should be kept the same when possible + // the main differences is that this version supports interleaved converts + // Utilities for transforming the A operand prior to issuing tensorcore math. + template > + CUTLASS_DEVICE void convert_tensor( + Tensor const& in, + Tensor& out, + cute::Int width = {}) { + // This is an element-wise conversion where we expect both tensors to have + // the same layout. As a result, we can cast as a cutlass array to use the + // fast numeric converters without worrying about indexing into the layout. + constexpr int N = cosize_v; + + // The inputs must be backed by registers & be statically sized. + static_assert(is_rmem::value, + "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, + "Output tensor for A conversion must come from registers"); + static_assert(is_static_v, + "Tensor layout for the conversion must be static"); + static_assert(cosize_v == size(TensorLayout{}), + "Cosize and size of the layout must be equal."); + static_assert( + N % ConversionVectorWidth == 0, + "Conversion vector width must divide cosize of the tensor layout."); + + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + + constexpr cutlass::FloatRoundStyle RoundStyle = + cutlass::FloatRoundStyle::round_to_nearest; + + using Converter = cutlass::InterleavedNumericArrayConverter< + IlvdBlkLayout, DstType, SrcType, ConversionVectorWidth, RoundStyle>; + + constexpr int NumIterations = N / ConversionVectorWidth; + + for (int ii = 0; ii < NumIterations; ++ii) { + SrcArray const* src_array_ptr = + reinterpret_cast(raw_pointer_cast(in.data())) + ii; + DstArray* dst_array_ptr = + reinterpret_cast(raw_pointer_cast(out.data())) + ii; + *dst_array_ptr = Converter::convert(*src_array_ptr); + } + } +}; + +} // namespace machete diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh new file mode 100644 index 00000000..046e6e5a --- /dev/null +++ b/csrc/quantization/machete/machete_mm_kernel.cuh @@ -0,0 +1,237 @@ +#pragma once + +#include +#include +#include + +// clang-format off +// The cutlass include order matters (annoyingly) +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/vllm_numeric_conversion.cuh" +#include "machete_collective_builder.cuh" +#include "machete_prepacked_layout.cuh" +#include "machete_interleaving_utils.cuh" + +namespace machete { + +using namespace cute; + +// NOTE This kernel computes D = alpha * A * B + beta * C by computing +// D^t = alpha * B^t * A^t + beta * C^t, this is because the wgmma +// instructions only support sourcing from registers for the left-hand +// operand, we want to upconvert/decompress the quantized operand in +// register. Since the primary use case we want to support is Y = XW^t where +// W is quantized, in this situation or right-hand operand is quantized so +// we compute the transpose to move it to the left-hand side. +template +struct MacheteKernelTemplate { + using MmaType = ElementA_; + using ElementA = ElementA_; + using ElementB = ElementB_; + using ElementD = ElementD_; + using ElementC = cute::conditional_t; + using ElementZ = ZeroT; + using ElementS = ScaleT; + + using ElementAccumulator = + AccumulatorT; // Element type for internal accumulation + using ElementCompute = AccumulatorT; // For Epilogue + + using BTypeTuple = cute::conditional_t< + with_scales, + cute::conditional_t, + cute::tuple>, + ElementB>; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + using LayoutScale = cutlass::layout::RowMajor; + // not actually used since B has the prepacked layout, but required by cutlass + using _LayoutB = cutlass::layout::ColumnMajor; + + // Interface strides expected by create_arguments (will get transposed) + using StrideA = cutlass::detail::TagToStrideA_t; + using StrideC = cutlass::detail::TagToStrideA_t; + using StrideD = cutlass::detail::TagToStrideA_t; + using StrideS = cutlass::detail::TagToStrideA_t; + using StrideZ = StrideS; + + using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; + using LayoutC_Transpose = + typename cutlass::layout::LayoutTranspose::type; + using LayoutD_Transpose = + typename cutlass::layout::LayoutTranspose::type; + + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using PrepackedLayoutB = + PrepackedLayoutBTemplate; + + static int constexpr TileShapeK = + 128 * 8 / cutlass::sizeof_bits::value; + static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v; + static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v; + static int constexpr AlignmentC = + (with_C) ? 128 / cutlass::sizeof_bits_v : 0; + static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v; + + using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{}, + cute::Int{})); + using ClusterShape = typename ScheduleConfig::ClusterShape; + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; + using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; + using TileScheduler = typename ScheduleConfig::TileScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose, + AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::VLLMCollectiveBuilder< + cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass, + BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose, + AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileScheduler>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // stride_B is unused (since B is prepacked), but still required by cutlass + using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>; + + using Arguments = typename Gemm::Arguments; + using MainloopArguments = typename GemmKernel::MainloopArguments; + using EpilogueArguments = typename GemmKernel::EpilogueArguments; + + template + static Arguments create_arguments( + cudaStream_t stream, + ElementA const* A_ptr, // A is an MxK matrix + Layout const& layout_A, + ElementB const* B_ptr, // B is an KxN prepacked matrix + ElementD* D_ptr, // D is an MxN matrix + Layout const& layout_D, + ElementC const* C_ptr, // C is an MxN matrix + std::optional> const& layout_C, + ElementS const* S_ptr, // S is an scale_KxN matrix + std::optional> const& layout_S, + ElementZ const* Z_ptr, // Z is an scale_KxN matrix + std::optional> const& layout_Z, + ElementCompute alpha, ElementCompute beta, + std::optional maybe_group_size) { + static_assert(!with_zeropoints || with_scales); + + int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A); + + int const group_size = maybe_group_size.value_or(K); + int const scale_k = (K + group_size - 1) / group_size; + + TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); + TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N); + + if constexpr (with_C) { + TORCH_CHECK(C_ptr && layout_C); + } else { + TORCH_CHECK(!C_ptr, "C not supported"); + } + + if constexpr (with_scales) { + TORCH_CHECK(S_ptr && layout_S); + TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N)); + } else { + TORCH_CHECK(!S_ptr, "Scales not supported"); + } + + if constexpr (with_zeropoints) { + TORCH_CHECK(Z_ptr && layout_Z); + TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N)); + TORCH_CHECK(layout_S && *layout_Z == *layout_S, + "Scales and zeros must have the same layout"); + } else { + TORCH_CHECK(!Z_ptr, "Zeropoints not supported"); + } + + // Transpose A and D + // A doesn't need to be transposed since cutlass expects a NxK matrix + // for B (which is At) + auto stride_At = layout_A.stride(); + auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride(); + auto stride_Ct = stride_Dt; + if (layout_C) { + stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride(); + } + + MainloopArguments mainloop_arguments{}; + EpilogueArguments epilogue_arguments{ + {alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt}; + + if constexpr (with_scales && with_zeropoints) { + auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride(); + mainloop_arguments = + MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At, + S_ptr, stride_S, group_size, Z_ptr}; + } else if constexpr (with_scales) { + auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride(); + mainloop_arguments = MainloopArguments{ + B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size}; + } else { + mainloop_arguments = + MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At}; + } + + return Arguments{cutlass::gemm::GemmUniversalMode::kGemm, + {N, M, K, 1}, + mainloop_arguments, + epilogue_arguments}; + }; + + static size_t get_workspace_size(Arguments const& args) { + return Gemm::get_workspace_size(args); + } + + static bool can_implement(Arguments const& args) { + return Gemm::can_implement(args) == cutlass::Status::kSuccess; + } + + static void run(Arguments const& args, void* workspace, cudaStream_t stream) { + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(args, workspace, stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, + "Machete kernel failed to initialize workspace"); + + status = gemm_op.run(stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed"); + } +}; + +}; // namespace machete diff --git a/csrc/quantization/machete/machete_mm_launcher.cuh b/csrc/quantization/machete/machete_mm_launcher.cuh new file mode 100644 index 00000000..e2604d4b --- /dev/null +++ b/csrc/quantization/machete/machete_mm_launcher.cuh @@ -0,0 +1,95 @@ +#pragma once + +#include +#include + +#include "machete_mm_kernel.cuh" +#include "cutlass_extensions/torch_utils.hpp" + +namespace machete { + +struct PyTorchArguments { + torch::Tensor const& A; + torch::Tensor const& B; + c10::optional const& scales; + c10::optional const& zeros; + c10::optional group_size; + c10::optional const& C; + c10::optional alpha; + c10::optional beta; + c10::optional schedule; +}; + +template +torch::Tensor run_impl(PyTorchArguments args) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A)); + + auto device = args.A.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + + using EleA = typename MacheteKernel::ElementA; + using EleB = typename MacheteKernel::ElementB; + using EleC = typename MacheteKernel::ElementC; + using EleD = typename MacheteKernel::ElementD; + using EleScale = typename MacheteKernel::ElementS; + using EleZero = typename MacheteKernel::ElementZ; + + using StrideA = typename MacheteKernel::StrideA; + using StrideC = typename MacheteKernel::StrideC; + using StrideD = typename MacheteKernel::StrideD; + using StrideS = typename MacheteKernel::StrideS; + using StrideZ = typename MacheteKernel::StrideZ; + + int M = args.A.size(0); + int N = args.B.size(1); + int K = args.A.size(1); + + // Allocate output + torch::Tensor D = + torch::empty({M, N}, torch::TensorOptions() + .dtype(equivalent_scalar_type_v) + .device(device)); + + auto const &A = args.A, &B = args.B; + auto const &C = args.C, &scales = args.scales, &zeros = args.zeros; + + auto layout_A = make_cute_layout(A, "A"); + auto layout_D = make_cute_layout(D, "D"); + auto layout_C = maybe_make_cute_layout(C, "C"); + auto layout_S = maybe_make_cute_layout(scales, "scales"); + auto layout_Z = maybe_make_cute_layout(zeros, "zeros"); + + auto A_ptr = static_cast(A.const_data_ptr()); + auto B_ptr = static_cast(B.const_data_ptr()); + auto D_ptr = static_cast(D.mutable_data_ptr()); + auto C_ptr = static_cast(C ? C->const_data_ptr() : nullptr); + auto S_ptr = + static_cast(scales ? scales->const_data_ptr() : nullptr); + auto Z_ptr = + static_cast(zeros ? zeros->const_data_ptr() : nullptr); + + auto arguments = MacheteKernel::create_arguments( + stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr, + layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0), + args.group_size.value_or(K)); + TORCH_CHECK(MacheteKernel::can_implement(arguments), + "Machete kernel cannot be run with these arguments"); + + size_t workspace_size = MacheteKernel::get_workspace_size(arguments); + torch::Tensor workspace = torch::empty( + workspace_size, torch::TensorOptions().dtype(torch::kU8).device(device)); + + MacheteKernel::run(arguments, workspace.mutable_data_ptr(), stream); + + return D; +}; + +template +struct GemmDispatcher { + static torch::Tensor dispatch(PyTorchArguments args); + static std::vector supported_schedules(); +}; + +}; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepack_kernel.cuh b/csrc/quantization/machete/machete_prepack_kernel.cuh new file mode 100644 index 00000000..8e021045 --- /dev/null +++ b/csrc/quantization/machete/machete_prepack_kernel.cuh @@ -0,0 +1,62 @@ +#pragma once + +#include "machete_mm_kernel.cuh" +#include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/torch_utils.hpp" + +namespace machete { + +template +static __global__ void prepack_B_kernel(BInTensor B_in, + BTiledOutTensor B_tiled_out) { + auto tB_in = local_tile(B_in, TileShapeNKL{}, + make_coord(blockIdx.x, blockIdx.y, blockIdx.z)); + auto tB_out = B_tiled_out(make_coord(_, _), + make_coord(blockIdx.x, blockIdx.y), blockIdx.z); + + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout, Stride<_32, _1>>{}, + Layout>{}); + + auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + + Tensor thr_tile_S = thr_copy.partition_S(tB_in); + Tensor thr_tile_D = thr_copy.partition_D(tB_out); + + // Construct a register-backed Tensor with the same shape as each thread's + // partition + auto fragment = make_tensor(shape(thr_tile_D)); + + // Copy from GMEM to RMEM and from RMEM to GMEM + copy(tiled_copy, thr_tile_S, fragment); + copy(Copy_Atom{}, fragment, thr_tile_D); +} + +template +static void prepack_B(cudaStream_t stream, + typename PrepackedLayoutB::ElementB const* B_in_ptr, + InLayout B_layout, + typename PrepackedLayoutB::ElementB* B_out_ptr) { + using TileShapeNKL = + decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{})); + auto ilvd_NKbNbKL_to_offset = + PrepackedLayoutB::ilvd_NKbNbKL_to_offset(shape(B_layout)); + + TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0); + TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0); + TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0); + + auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{}); + auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{}); + auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{}); + + auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout); + auto B_tiled_out = + make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset); + + prepack_B_kernel + <<>>(B_in, B_tiled_out); +} + +}; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepack_launcher.cuh b/csrc/quantization/machete/machete_prepack_launcher.cuh new file mode 100644 index 00000000..686dd68b --- /dev/null +++ b/csrc/quantization/machete/machete_prepack_launcher.cuh @@ -0,0 +1,71 @@ +#pragma once + +#include "machete_prepack_kernel.cuh" +#include "cutlass_extensions/torch_utils.hpp" + +namespace machete { + +template +torch::Tensor prepack_impl(torch::Tensor const B) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(B)); + using ElementB = typename PrepackedLayoutB::ElementB; + using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK; + + auto device = B.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + auto B_ptr = static_cast(B.const_data_ptr()); + // elements per storage item for B + auto eles_per_storage = + (B.dtype().itemsize() * 8) / cute::sizeof_bits_v; + + // torch B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to + // match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L) + auto Bt_packed = B.t(); + + TORCH_CHECK( + (B.size(0) * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0, + "B.shape[0] (in terms of unpacked elements) must be a multiple of ", + size<1>(PPBlockShape_NK{})); + TORCH_CHECK(B.size(1) % size<0>(PPBlockShape_NK{}) == 0, + "B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{})); + + using StrideB = cutlass::detail::TagToStrideB_t; + auto const l_Bt_packed = make_cute_layout(Bt_packed, "B"); + + // convert (N,packed_K,L) layout to (N,K,L) layout + // in effect we want to do: blocked_product(layout_Bt_packed, + // make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}), + // Step<_1, _0, _2>{})); + // but blocked_product does not support dynamic strides so we implement the + // equivalent manually, + // new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L) + // new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage) + // when s1 == 1 + TORCH_CHECK(stride<1>(l_Bt_packed) == 1); + // clang-format off + auto const layout_Bt = make_layout( + transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) { + return idx == 1 ? ele * eles_per_storage : ele; + }), + transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) { + return idx != 1 ? ele * eles_per_storage : ele; + })); + // clang-format on + + // Allocate output + torch::Tensor D = torch::empty_like(B); + + prepack_B(stream, B_ptr, layout_Bt, + static_cast(D.mutable_data_ptr())); + + return D; +}; + +template +struct PrepackBDispatcher { + static torch::Tensor dispatch(torch::Tensor B); +}; + +}; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepacked_layout.cuh b/csrc/quantization/machete/machete_prepacked_layout.cuh new file mode 100644 index 00000000..78e2cc5e --- /dev/null +++ b/csrc/quantization/machete/machete_prepacked_layout.cuh @@ -0,0 +1,220 @@ +#pragma once + +#include +#include +#include + +// clang-format off +// The cutlass include order matters (annoyingly) + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#include "cutlass_extensions/cute_utils.cuh" +#include "machete_collective_builder.cuh" +#include "machete_interleaving_utils.cuh" + +namespace machete { + +using namespace cute; + +struct IlvBlkLayoutAuto {}; + +// This defines a prepacked layout for the B matrix, where the matrix is broken +// up into PPBlockShape_NK blocks. The data within each block is then compactly +// stored in memory such that when performing a TiledMMA operation with the same +// shape as prepacked block, all the data for a given thread is contiguous in +// memory. This allows us to use wider shared memory loads when loading B from +// shared memory. The values within a thread are also potentially interlaeved +// inorder to allow for more efficient upconverting. +// +// The contract here is that the `TiledMma` determined below matches the one +// ultimately used in the kernel. (this is also why the other element types are +// required along with the kernel schedule) +template +// clang-format on +struct PrepackedLayoutBTemplate { + using MmaType = ElementA_; + using ElementA = ElementA_; + using ElementB = ElementB_; + using ElementD = ElementD_; + using ElementAccumulator = + AccumulatorT; // Element type for internal accumulation + using ElementMma = MmaType; + + // Only use interleaved layouts for subbyte weights, prmt instructions makes + // non-interleaved layouts for 8bit+ weights efficient enough we don't need + // iterleaved layouts + using IlvdBlkLayout = std::conditional_t< + std::is_same_v, + std::conditional_t <= 4, + decltype(get_interleaved_blk_layout< + ElementB, sizeof_bits_v, 32>()), + void>, + IlvBlkLayout_>; + + // TODO (LucasWilkinson): compare the performance for other sizes + // Prepacked block shape, smallest layout atom for loading into registers + // (can contain multiple wgmma instructions worth of data in one block) + // We ideally want this to be configured such that a thread can perform 128bit + // loads, i.e. we amount of data associated with each thread within a + // prepacked block is a multiple of 128bits, when using a cooperative sechdule + // we have 256 threads working a single block at a time, this means each + // thread works on `sizeof_bits_v * (128*64) / 256` bits of data, + // for a 4bit type this would be 128bits + using PPBlockShape_NK = Shape<_128, _64>; + + // Create the shape of the tile anticipated to be used by the GEMM kernel, + // when the kernel executes we will compute `Ct = Bt * At` since the + // quantized weights (B), must be the lhs operand so the flow through + // registers. + // The _128 here doesn't actually impact the shape of the stored tile directly + // but may impact the op selected by rs_op_selector + using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{}, + size<1>(PPBlockShape_NK{}))); + + static constexpr cute::GMMA::Major GmmaMajorB = + gmma_rs_tag_to_major_B(); + + // For coop schedules we have two warp groups cooperatively issuing wgmma + // instructions so we use 2 atoms along the M dim (one for each warpgroup) + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutMNK{})); + + // Prepacked block, (athrid, val) -> (N,K) + // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K) + CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() { + return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{})); + } + + // Prepacked block, (N,K) -> (athrid, val) + // i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) + CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() { + return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{}); + } + + // Prepacked block, (athrid, val) -> (storage_offset) + // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx) + CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() { + // Return iterleaved layout + return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{}); + } + + // Prepacked block, (athrid, val) -> (storage_offset) + // i.e. ((ThrV,(ThrM,ThrK)),(IlvdFrgV,(RestM,RestK,...))) -> (storage_idx) + CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() { + auto layout_no_interleave = + make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{}); + + if constexpr (std::is_same_v) { + return layout_no_interleave; + } else { + // interleave by transforming FrgV into interleaved blocks where each + // block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is + // (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4) + // if FrgV is {A, B, C, D, E, F, G, H} + // then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H} + auto frgV = get<1, 0>(layout_no_interleave); + auto ilvdBlk = IlvdBlkLayout{}; + static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4"); + auto ilvd_FrgV = make_layout( + make_shape(shape(ilvdBlk), Int{}), + make_stride(stride(ilvdBlk), size(ilvdBlk))); + + // Return iterleaved layout + return make_layout( + get<0>(layout_no_interleave), + make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave))); + } + } + + // Prepacked block, (M,K) -> (storage_offset) + CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() { + // do (M,K) -> (athrid, val) -> (storage_idx) + return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV()); + } + + // ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset( + Shape_NKL shape_mkl) { + constexpr auto block_layout = ppblock_TV_to_offset(); + + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + + // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx) + auto result = make_layout( + block_layout, + make_layout(blocks_shape, + compact_col_major(blocks_shape, size(block_layout)))); + + // ((athrid, val), (BlocksN, BlocksK, L)) + // => ((athrid, val), (BlocksN, BlocksK), L) + return group<1, 3>(result(_, repeat(result)>(_))); + } + + // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset( + Shape_NKL shape_mkl) { + constexpr auto block_layout = ppblock_ilvd_NK_to_offset(); + + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + + // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx) + auto result = make_layout( + block_layout, + make_layout(blocks_shape, + compact_col_major(blocks_shape, size(block_layout)))); + + // ((athrid, val), (BlocksN, BlocksK, L)) => ((athrid, val), (BlocksN, + // BlocksK), L) + return group<1, 3>(result(_, repeat(result)>(_))); + } + + // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L) + template + CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) { + auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})), + make_layout(size<1>(PPBlockShape_NK{}))); + + // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L) + auto tiled_A = zipped_divide(make_layout(shape_mkl), tile); + return tiled_A.compose(ppblock_TV_to_NK(), _); + } + + // (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L) + template + CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) { + auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl); + return blocked_product(ppblock_NK_to_TV(), + make_layout(shape<1>(TVbNbK_to_NKL_layout))); + } +}; + +}; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_pytorch.cu b/csrc/quantization/machete/machete_pytorch.cu new file mode 100644 index 00000000..ef36a490 --- /dev/null +++ b/csrc/quantization/machete/machete_pytorch.cu @@ -0,0 +1,79 @@ +#include "machete_mm_launcher.cuh" +#include "machete_prepack_launcher.cuh" +#include "core/scalar_type.hpp" + +namespace machete { + +using namespace vllm; + +// +// Utils (type dispatching) +// + +template +static auto scalar_type_dispatch(ScalarType const& type, Fn fn) { + if (type == vllm::kU4) { + return fn(cutlass::uint4b_t{}); + } else if (type == vllm::kU8) { + return fn(cutlass::uint8_t{}); + } else if (type == vllm::kU4B8) { + return fn(cutlass::vllm_uint4b8_t{}); + } else if (type == vllm::kU8B128) { + return fn(cutlass::vllm_uint8b128_t{}); + } else { + TORCH_CHECK(false, "Unsupported type ", type.str()); + } +} + +#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \ + AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__) + +#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ + AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__)) + +// +// Interface +// + +std::vector supported_schedules(ScalarTypeTorchPtr const& btype) { + return scalar_type_dispatch(*btype, [&](auto BType) { + return GemmDispatcher::supported_schedules(); + }); +} + +torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, + ScalarTypeTorchPtr const& btype, + c10::optional const& scales, + c10::optional const& zeros, + c10::optional group_size, + c10::optional const& C, + c10::optional alpha, c10::optional beta, + c10::optional schedule) { + auto args = PyTorchArguments{.A = A, + .B = B, + .scales = scales, + .zeros = zeros, + .group_size = group_size, + .C = C, + .alpha = alpha, + .beta = beta, + .schedule = schedule}; + + return scalar_type_dispatch(*btype, [&](auto BType) { + return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES( + A.scalar_type(), "machete_gemm", [&] { + using ComputeType = equivalent_cutlass_type_t; + return GemmDispatcher::dispatch(args); + }); + }); +} + +torch::Tensor prepack_B(torch::Tensor const& B, + ScalarTypeTorchPtr const& btype) { + return scalar_type_dispatch(*btype, [&](auto BType) { + return PrepackBDispatcher::dispatch(B); + }); +} + +}; // namespace machete diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index e26c2e28..6d1f53b7 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -133,6 +133,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm); ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm); + // Machete (Dense) Optimized Mixed Precision GEMM for Hopper. + ops.def("machete_supported_schedules", &machete::supported_schedules); + ops.def( + "machete_gemm(Tensor A, Tensor B," + " __torch__.torch.classes._core_C.ScalarType btype," + " Tensor? scales, Tensor? zeros, int? group_size," + " Tensor? C, float? alpha, float? beta, str? schedule)" + "-> Tensor"); + ops.impl("machete_gemm", torch::kCUDA, &machete::gemm); + ops.def( + "machete_prepack_B(Tensor B," + " __torch__.torch.classes._core_C.ScalarType btype)" + "-> Tensor"); + ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B); + // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def("gptq_marlin_gemm", &gptq_marlin_gemm); ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm); diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py new file mode 100644 index 00000000..dadf5944 --- /dev/null +++ b/tests/kernels/test_machete_gemm.py @@ -0,0 +1,272 @@ +"""Tests for the machete kernel. + +Run `pytest tests/kernels/test_machete_gemm.py`. +""" + +import math +from typing import Optional, Tuple + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, quantize_weights) +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +MNK_SHAPES = [ + (1, 128, 128), + (1, 512, 1024), + (1, 4096, 4096), + (13, 8192, 4096), + (26, 4096, 8192), + (1, 4096, 4096), + (257, 128, 4096), + (257, 4224, 4160), + (257, 4096, 4096), + (64, 4096, 4096), +] + +ACT_TYPES = [torch.float16, torch.bfloat16] +WTYPE_ZEROPOINTS = [ + # GPTQ style + (scalar_types.uint4b8, False), + (scalar_types.uint8b128, False), + # AWQ style + (scalar_types.uint4, True), + (scalar_types.uint8, True), +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 + + +def rand_data(shape, dtype=torch.float16): + return 10 * (torch.rand(shape, dtype=dtype, device="cuda") - 0.3) + + +def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): + return zps if zps is None else -1 * s * (zps.to(s.dtype)) + + +def machete_quantize_and_pack(w: torch.Tensor, + wtype: ScalarType, + group_size: int, + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights( + w, + wtype, + group_size, + zero_points=zero_points, + # to match how the kernel applies zps + ref_zero_points_after_scales=True) + + w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # convert to col major + w_q_machete = ops.machete_prepack_B(w_q, wtype) + + return w_ref, w_q_machete, w_s, w_zp + + +def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor, + wtype: ScalarType, group_size: int, + zero_points: bool): + w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( + b, wtype, group_size, zero_points) + + output_ref = torch.matmul(a, w_ref) + + output = ops.machete_gemm( + a=a, + b_q=w_q_packed, + b_type=wtype, + b_scales=w_s, + b_zeros=maybe_convert_zeropoints(w_zp, w_s), + b_group_size=group_size, + ) + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if zero_points else min(5e-2 * math.sqrt(a.shape[1]), 1) + torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) +@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) +@pytest.mark.parametrize("group_size", [128, None]) +def test_machete_all_schedules(shape, atype: torch.dtype, + wtype_zeropoints: Tuple[ScalarType, bool], + group_size: Optional[int]): + m, n, k = shape + wtype, zero_points = wtype_zeropoints + + if group_size is not None and k % group_size != 0: + return + + print(f"MNK = {m} {n} {k}") + + # Normalize group_size + if group_size is None: + group_size = k + assert group_size <= k + + a = rand_data((m, k), atype) + w = rand_data((k, n), atype) + + w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack( + w, wtype, group_size, zero_points) + + output_ref = torch.matmul(a, w_ref) + + for schedule in ops.machete_supported_schedules(wtype): + output = ops.machete_gemm( + a, + b_q=w_q_machete, + b_type=wtype, + b_scales=w_s, + b_zeros=maybe_convert_zeropoints(w_zp, w_s), + b_group_size=group_size, + schedule=schedule, + ) + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) + torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol),\ + f"Schedule failed {schedule}" + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) +@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) +@pytest.mark.parametrize("group_size", [128, None]) +def test_machete_heuristic(shape, atype: torch.dtype, + wtype_zeropoints: Tuple[ScalarType, bool], + group_size: Optional[int]): + m, n, k = shape + wtype, zero_points = wtype_zeropoints + + if group_size is not None and k % group_size != 0: + return + + # Normalize group_size + if group_size is None: + group_size = k + assert group_size <= k + + a = rand_data((m, k), atype) + b = rand_data((k, n), atype) + + machete_gemm_test_helper(a, b, wtype, group_size, zero_points) + + +# Test working on other devices +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_machete_devices(device: str): + m, n, k = 512, 4096, 4096 + wtype = scalar_types.uint4b8 + group_size = 128 + zero_points = False + + print(f"MNK = {m} {n} {k}, device = {device}") + + a = rand_data((m, k), torch.float16).to(device) + b = rand_data((k, n), torch.float16).to(device) + + machete_gemm_test_helper(a, b, wtype, group_size, zero_points) + + +# Test working with a subset of A and B +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +def test_machete_subset(): + big_m, big_n, big_k = 1024, 1024, 1024 + m, n, k = 512, 512, 512 + wtype = scalar_types.uint4b8 + group_size = 128 + zero_points = False + + whole_a = rand_data((big_m, big_k), torch.float16) + whole_b = rand_data((big_k, big_n), torch.float16) + + a = whole_a[0:m, 0:k] + b = whole_b[0:k, 0:n] + + machete_gemm_test_helper(a, b, wtype, group_size, zero_points) + + +# Test to make sure cuda graphs work +class MacheteLayer(torch.nn.Module): + + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def forward(self, a): + return ops.machete_gemm(**self.kwargs) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +def test_machete_cuda_graph(): + m, n, k = 512, 4096, 4096 + + a = rand_data((m, k), torch.float16) + b = rand_data((k, n), torch.float16) + wtype = scalar_types.uint4b8 + group_size = 128 + zero_points = False + + w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( + b, wtype, group_size, zero_points) + + # Construct a trivial model with a single layer that calls a machete kernel + model = MacheteLayer( + a=a, + b_q=w_q_packed, + b_type=wtype, + b_scales=w_s, + b_zeros=maybe_convert_zeropoints(w_zp, w_s), + b_group_size=group_size, + ) + + output_ref = torch.matmul(a, w_ref) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output = model(a) + output.zero_() + g.replay() + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) + torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 1f0a111a..b89a90ef 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -329,6 +329,32 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, num_bits, size_m, size_n, size_k) +# machete +def machete_supported_schedules(b_type: ScalarType) -> List[str]: + return torch.ops._C.machete_supported_schedules(b_type) + + +def machete_gemm( + a: torch.Tensor, + b_q: torch.Tensor, # Should be the tensor returned by machete_prepack_B + b_type: ScalarType, + b_scales: Optional[torch.Tensor] = None, + b_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + c: Optional[torch.Tensor] = None, + alpha: Optional[float] = None, + beta: Optional[float] = None, + schedule: Optional[str] = None, +) -> torch.Tensor: + return torch.ops._C.machete_gemm(a, b_q, b_type, b_scales, b_zeros, + b_group_size, c, alpha, beta, schedule) + + +def machete_prepack_B(b_q_weight: torch.Tensor, + b_type: ScalarType) -> torch.Tensor: + return torch.ops._C.machete_prepack_B(b_q_weight, b_type) + + # fp8 def scaled_fp8_quant( input: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 7f9081b2..33f24ff5 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -81,7 +81,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): def quantize_weights(w: torch.Tensor, quant_type: ScalarType, group_size: int, - zero_points: bool = False): + zero_points: bool = False, + ref_zero_points_after_scales: bool = False): assert quant_type.is_integer(), \ "Floating point quantization may work but has not been tested" @@ -126,7 +127,13 @@ def quantize_weights(w: torch.Tensor, w_q = torch.clamp(w_q, min_q_val, max_q_val) # Compute ref (dequantized) - w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. + if ref_zero_points_after_scales and zero_points: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s if quant_type.has_bias(): w_q += quant_type.bias