diff --git a/CMakeLists.txt b/CMakeLists.txt index 6c946fc5..c823c9ff 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -245,7 +245,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - GIT_TAG v3.6.0 + GIT_TAG v3.7.0 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. @@ -299,7 +299,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # CUDA 12.0 or later (and only work on Hopper, 9.0a for now). cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") + set(SRCS + "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index d0353bc8..b87496ca 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -3,7 +3,7 @@ import copy import itertools import pickle as pkl import time -from typing import Callable, Iterable, List, Tuple +from typing import Callable, Iterable, List, Optional, Tuple import torch import torch.utils.benchmark as TBenchmark @@ -12,6 +12,8 @@ from utils import make_rand_tensors from weight_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + w8a8_block_fp8_matmul) from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) @@ -38,8 +40,15 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, ).blocked_autorange(min_run_time=min_run_time) -def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench_int8( + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: + """Benchmark INT8-based kernels.""" assert dtype == torch.int8 a, b = make_rand_tensors(torch.int8, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) @@ -48,155 +57,132 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) + bench_fns = { + "pytorch_bf16_bf16_bf16_matmul-no-scales": + lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) + ), + "pytorch_fp16_fp16_fp16_matmul-no-scales": + lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), + "cutlass_i8_i8_bf16_scaled_mm": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16), + "cutlass_i8_i8_bf16_scaled_mm_bias": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, + bias), + "cutlass_i8_i8_bf16_scaled_mm_azp": + lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. + bfloat16, azp_adj), + "cutlass_i8_i8_bf16_scaled_mm_azp_bias": + lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. + bfloat16, azp_adj, None, bias), + "cutlass_i8_i8_bf16_scaled_mm_azp_pt": + lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. + bfloat16, azp_adj, azp), + "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": + lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. + bfloat16, azp_adj, azp, bias), + } + timers = [] - # pytorch impl - bfloat16 - timers.append( - bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, a.to(dtype=torch.bfloat16), - b.to(dtype=torch.bfloat16))) - - # pytorch impl - float16 - timers.append( - bench_fn(label, sub_label, - "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, - a.to(dtype=torch.float16), b.to(dtype=torch.float16))) - - # cutlass impl - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, - torch.bfloat16)) - - # cutlass with bias - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, - bias)) - - # cutlass with azp per-tensor - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp", - ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, - torch.bfloat16, azp_adj)) - - # cutlass with azp per-tensor + bias - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias", - ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, - torch.bfloat16, azp_adj, None, bias)) - - # cutlass with azp per-token - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt", - ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, - torch.bfloat16, azp_adj, azp)) - - # cutlass with azp per-token + bias - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias", - ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, - torch.bfloat16, azp_adj, azp, bias)) + for name, fn in bench_fns.items(): + # If bench_kernels is None, run all. Otherwise, run only exact matches. + if bench_kernels is None or name in bench_kernels: + print(f"Running {name}") + timers.append(bench_fn(label, sub_label, name, fn)) return timers -def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench_fp8( + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: + """Benchmark FP8-based kernels.""" assert dtype == torch.float8_e4m3fn a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) + a_cont = a.contiguous() scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + block_scale_a = torch.rand((m, k // 128), + device="cuda", + dtype=torch.float32) + block_scale_b = torch.rand((k // 128, n // 128), + device="cuda", + dtype=torch.float32) + block_scale_a_M_major = block_scale_a.t().contiguous().t() + block_scale_b_K_major = block_scale_b.t().contiguous().t() bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + print(m, k, n) + + bench_fns = { + "pytorch_bf16_bf16_bf16_matmul-no-scales": + lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) + ), + "pytorch_fp16_fp16_fp16_matmul-no-scales": + lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), + "pytorch_fp8_fp8_fp16_scaled_mm": + lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.float16), + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": + lambda: torch._scaled_mm(a, + b, + scale_a, + scale_b, + out_dtype=torch.float16, + use_fast_accum=True), + "pytorch_fp8_fp8_bf16_scaled_mm": + lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.bfloat16), + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": + lambda: torch._scaled_mm(a, + b, + scale_a, + scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True), + "cutlass_fp8_fp8_bf16_scaled_mm": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16), + "cutlass_fp8_fp8_fp16_scaled_mm": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16), + "cutlass_fp8_fp8_bf16_scaled_mm_bias": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, + bias), + "cutlass_fp8_fp8_fp16_scaled_mm_bias": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16, + bias.to(dtype=torch.float16)), + "triton_fp8_fp8_fp16_scaled_mm_blockwise": + lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a, + block_scale_b.t(), (128, 128)), + "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": + lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major, + block_scale_b_K_major, torch.float16), + } + timers = [] - - # pytorch impl w. bf16 - timers.append( - bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"))) - - # pytorch impl: bf16 output, without fp8 fast accum - timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16)) - - # pytorch impl: bf16 output, with fp8 fast accum - timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16, - use_fast_accum=True)) - - # pytorch impl: fp16 output, without fp8 fast accum - timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_fp16_scaled_mm", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.float16)) - - # pytorch impl: fp16 output, with fp8 fast accum - timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.float16, - use_fast_accum=True)) - - # cutlass impl: bf16 output - timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, - torch.bfloat16)) - # cutlass impl: fp16 output - timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16)) - - # cutlass impl: bf16 output, with bias - timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, - bias)) - - # cutlass impl: fp16 output, with bias - timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16, - bias.to(dtype=torch.float16))) + for name, fn in bench_fns.items(): + # If bench_kernels is None, run all. Otherwise, run only exact matches. + if bench_kernels is None or name in bench_kernels: + print(f"Running {name}") + timers.append(bench_fn(label, sub_label, name, fn)) return timers -def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench(dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: if dtype == torch.int8: - return bench_int8(dtype, m, k, n, label, sub_label) + return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels) if dtype == torch.float8_e4m3fn: - return bench_fp8(dtype, m, k, n, label, sub_label) + return bench_fp8(dtype, m, k, n, label, sub_label, bench_kernels) raise ValueError("unsupported type") @@ -207,18 +193,22 @@ def print_timers(timers: Iterable[TMeasurement]): def run(dtype: torch.dtype, - MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + MKNs: Iterable[Tuple[int, int, int]], + bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: results = [] for m, k, n in MKNs: - timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", - f"MKN=({m}x{k}x{n})") + timers = bench(dtype, + m, + k, + n, + f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})", + bench_kernels=bench_kernels) print_timers(timers) results.extend(timers) - return results -# output makers def make_output(data: Iterable[TMeasurement], MKNs: Iterable[Tuple[int, int, int]], base_description: str, @@ -232,15 +222,11 @@ def make_output(data: Iterable[TMeasurement], pkl.dump(data, f) -# argparse runners - - def run_square_bench(args): dim_sizes = list( range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) - data = run(args.dtype, MKNs) - + data = run(args.dtype, MKNs, bench_kernels=args.kernels) make_output(data, MKNs, f"square_bench-{args.dtype}") @@ -251,8 +237,7 @@ def run_range_bench(args): Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes MKNs = list(zip(Ms, Ks, Ns)) - data = run(args.dtype, MKNs) - + data = run(args.dtype, MKNs, bench_kernels=args.kernels) make_output(data, MKNs, f"range_bench-{args.dtype}") @@ -278,7 +263,7 @@ def run_model_bench(args): for k, n in KNs: MKNs.append((m, k, n)) - data = run(args.dtype, MKNs) + data = run(args.dtype, MKNs, bench_kernels=args.kernels) model_bench_data.append(data) # Print all results @@ -328,6 +313,15 @@ Benchmark Cutlass GEMM. type=to_torch_dtype, required=True, help="Available options are ['int8', 'fp8']") + parser.add_argument( + "--kernels", + nargs="+", + type=str, + default=None, + help= + "Exact names of the kernels to benchmark. If not set, runs all kernels." + ) + subparsers = parser.add_subparsers(dest="cmd") square_parser = subparsers.add_parser("square_bench") @@ -362,4 +356,4 @@ Benchmark Cutlass GEMM. model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() - args.func(args) \ No newline at end of file + args.func(args) diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp index ba9f40a2..ddfaca27 100644 --- a/csrc/core/math.hpp +++ b/csrc/core/math.hpp @@ -1,7 +1,14 @@ +#pragma once + #include #include -inline uint32_t next_pow_2(uint32_t const num) { +inline constexpr uint32_t next_pow_2(uint32_t const num) { if (num <= 1) return num; return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} + +template +inline constexpr std::enable_if_t, T> ceil_div(T a, T b) { + return (a + b - 1) / b; } \ No newline at end of file diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index 07c9e46c..febc4ecc 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -32,3 +32,20 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { } int32_t get_sm_version_num(); + +/** + * A wrapper for a kernel that is used to guard against compilation on + * architectures that will never use the kernel. The purpose of this is to + * reduce the size of the compiled binary. + * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef + * into code that will be executed on the device where it is defined. + */ +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + Kernel::operator()(std::forward(args)...); +#endif + } +}; \ No newline at end of file diff --git a/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp b/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp new file mode 100644 index 00000000..ec75c29e --- /dev/null +++ b/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp @@ -0,0 +1,123 @@ +// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl +// clang-format off +#pragma once + +#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" + +#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS (BlockScaled Builders) +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + int ScaleGranularityM +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum, + cute::enable_if_t< + not detail::is_use_rmem_A()> +> { + using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum; + + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v); + static constexpr bool IsFP8Input = detail::is_input_fp8(); + static_assert((!IsFP8Input || !IsArrayOfPointersGemm), + "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now."); + + // For fp32 types, map to tf32 MMA value type + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + static constexpr bool IsCooperative = cute::is_any_of_v>; + using AtomLayoutMNK = cute::conditional_t>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp b/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp new file mode 100644 index 00000000..13b90e99 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp @@ -0,0 +1,183 @@ +// clang-format off +// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp + +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cute/algorithm/clear.hpp" +#include "cute/tensor.hpp" + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////FP8 Accumulation/////////////////////////// +////////////////////////////////////////////////////////////////////////////// +/// This class provides API to promote (add) or scale (multiply_add) the results +/// from the tensor core accumulators to the main accumulators when the number +/// of MMAs reaches the max number of MMA interval specified by user, after that +/// the tensor core accumulators are zeroed. +////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +template < + class EngineAccum, + class LayoutAccum> +struct GmmaFP8AccumulationWithScale { + using TensorAccum = cute::Tensor; + using ElementAccumulator = typename EngineAccum::value_type; + + static_assert(is_static::value, "Accumulator Layout should be static"); + static_assert(is_rmem::value , "Accumulator tensor must be rmem resident."); + +private: + TensorAccum& accum_; + TensorAccum accum_temp_; + + uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted. + uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop + uint32_t mma_count_; // current executed MMAs + uint32_t reset_accum_flag_; // accum needs to be zeroed or not. + + // promote or `add` the partial accumulators to main accumulator (FADD). + CUTLASS_DEVICE + void promote_core() { + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum_); ++i) { + accum_(i) += accum_temp_(i); + } + } + + // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA). + template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale_core(const cute::Tensor &scale) { + using TensorScale = cute::Tensor; + + static_assert(is_static::value, "Scale Layout should be static"); + static_assert(is_rmem::value , "Scale tensor must be rmem resident."); + + static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape."); + + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum_); ++i) { + accum_(i) += accum_temp_(i) * scale(i); + } + } + +public: + CUTLASS_DEVICE + GmmaFP8AccumulationWithScale( + TensorAccum &accum, + uint32_t accum_promotion_interval, + uint32_t mma_count_per_mainloop_iteration) + : accum_(accum), + accum_promotion_interval_(accum_promotion_interval), + mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration), + mma_count_(0), + reset_accum_flag_(0) + { + accum_temp_ = cute::make_fragment_like(accum); + } + + // + // Methods (Common) + // + + CUTLASS_DEVICE + TensorAccum& operator()() { + return accum_temp_; + } + + /// prepare the MMA accumulators when initialization or zeroing is required. + CUTLASS_DEVICE + bool prepare_if_needed() { + return reset_accum_flag_; + } + + // + // Methods (for FADD version) + // + + /// promote (add) the results from the MMA accumulators to main accumulator if needed. + CUTLASS_DEVICE + void promote_if_needed() { + mma_count_ += mma_count_per_mainloop_iteration_; + reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + if (reset_accum_flag_) { + promote_core(); + mma_count_ = 0; + } + } + + /// promote (add) the residue results from the MMA accumulators to main accumulator if needed. + CUTLASS_DEVICE + void promote_residue_if_needed() { + if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { + promote_core(); + } + } + + // + // Methods (for FFMA version) + // + + /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed. + template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale_if_needed(const cute::Tensor &scale) { + mma_count_ += mma_count_per_mainloop_iteration_; + reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + if (reset_accum_flag_) { + scale_core(scale); + mma_count_ = 0; + } + } + + /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed. + template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale_residue_if_needed(const cute::Tensor &scale) { + if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { + scale_core(scale); + } + } +}; + +} // namespace cutlass::gemm::collective diff --git a/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp new file mode 100644 index 00000000..928a9500 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -0,0 +1,730 @@ +// clang-format off +// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp + +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm80.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass_extensions/gemm/dispatch_policy.hpp" +#include "cutlass_extensions/gemm/collective/fp8_accumulation.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + int ScaleGranularityM_, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using ElementBlockScale = ElementAccumulator; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + // Two threads per CTA are producers (1 for operand tile and 32 for scales) + static constexpr int NumProducerThreadEvents = 33; + + static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_; + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // Block scaling gmem-to-smem copy atom + using SmemBlockScalingCopyAtomA = Copy_Atom, ElementBlockScale>; + using SmemBlockScalingCopyAtomB = Copy_Atom, ElementBlockScale>; + + // Block scaling smem layout + using SmemLayoutScaleA = Layout, Int>>; + using SmemLayoutScaleB = Layout>, Stride<_1>>; // `ScaleNsPerTile` is always 1. + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v, + "ElementAccumulator and ElementBlockScale should be same datatype"); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_A; // mxk + cute::array_aligned> smem_B; // nxk + cute::array_aligned> smem_scale_A; // ScaleMsPerTile x k + cute::array_aligned> smem_scale_B; // 1xk + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + ElementBlockScale const* ptr_scale_A; + ElementBlockScale const* ptr_scale_B; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,0), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,0), + TileShape{}, + ClusterShape{})); + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + // Block scaling factors for A and B + ElementBlockScale const* ptr_scale_A; + ElementBlockScale const* ptr_scale_B; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + + return { + tma_load_a, + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk, + args.ptr_scale_A, + args.ptr_scale_B + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + constexpr auto scales_m = Int{}; + auto tM = get<2>(gA_mkl.shape()); + auto tN = get<2>(gB_nkl.shape()); + auto tK = get<3>(gA_mkl.shape()); + + // Make the tiled views of scale tensors + auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l) + auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{}); + auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l) + auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{}); + + // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and + // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl. + Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l) + Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class TensorScaleA, class TensorScaleB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + // Blockscaling: Tma loads for load_input and CpAsync for load_scale + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k) + Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + + // Block scaling: load_scale has scaling tensors in global memory which are not tiled + Tensor mScaleA_mkl = get<2>(load_inputs); + Tensor mScaleB_nkl = get<3>(load_inputs); + auto scales_m = get<0>(mScaleA_mkl.shape()); + + Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape()); + + Tensor gScaleA = local_tile( + mScaleA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) + Tensor cScaleA = local_tile( + cScaleA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); + Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1) + + // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128 + TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, + Layout>{}, Layout>{}); // (1,1,1) + TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, + Layout>{}, Layout>{}); // (1,1,1) + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x); + + Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA); + Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA); + Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA); + + Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB); + Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Allocate predicate tensors for a_scales (since we can't guarantee that + // all scales are valid, since we could have a partial tiles along M) + Tensor tApA_ScaleA = make_tensor(shape(tAsA_ScaleA(_,_,0))); + #pragma unroll + for (int i = 0; i < size(tApA_ScaleA); ++i) { + tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + int write_stage = smem_pipe_write.index(); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + // Copy operands A and B from global memory to shared memory + if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + // Copy scale tensors from global memory to shared memory + copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage)); + copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage)); + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail( + MainloopPipeline pipeline, + PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Block scaling + Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), + Layout< + Shape, Int>, cute::tuple_element_t<1, TileShape>, Int>, + Stride, _0, Int> + >{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k) + Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Per block scale values for operand A and B + + using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout. + using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above + + Tensor tCrScaleAViewAsC = make_tensor(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N) + ElementBlockScale scale_b; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA)); + warpgroup_fence_operand(accumulation()); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + int read_stage = smem_pipe_read.index(); + + // Load per block scale values from shared memory to registers. + scale_b = sScaleB[read_stage]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + } + if constexpr (ScaleMsPerTile == 1) { + static_assert(size(RegLayoutScaleAEssential{}) == 1); + tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. + } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } + + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` + accumulation.scale_if_needed(tCrScaleAViewAsC); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accumulation()); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N) + scale_b = sScaleB[read_stage]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + } + if constexpr (ScaleMsPerTile == 1) { + static_assert(size(RegLayoutScaleAEssential{}) == 1); + tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. + } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accumulation()); + + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` + accumulation.scale_if_needed(tCrScaleAViewAsC); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + accumulation.scale_residue_if_needed(tCrScaleAViewAsC); + + warpgroup_fence_operand(accumulation()); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/dispatch_policy.hpp b/csrc/cutlass_extensions/gemm/dispatch_policy.hpp new file mode 100644 index 00000000..df809e27 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/dispatch_policy.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include "cutlass/gemm/dispatch_policy.hpp" + +namespace cutlass::gemm { + +////////////////////////////////////////////////////////////////////////////// + +// FP8 related policies (including Blocked Scaled Accumulation) +// `ScaleGranularityM` specifies scaling granularity along M, while zero-value +// `ScaleGranularityM` indicates that scaling granularity is +// `size<0>(TileShape_MNK{})` along M. +template +struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum + : KernelTmaWarpSpecializedCooperative {}; + +// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp +// specialized dynamic schedule For FP8 kernels with Block Scaling +template , + class KernelSchedule = KernelTmaWarpSpecialized, + int ScaleGranularityM = + 0 // `ScaleGranularityM` specifies scaling granularity along M, + // while zero-value `ScaleGranularityM` indicates that scaling + // granularity is `size<0>(TileShape_MNK{})` along M. + > +struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8 + : MainloopSm90TmaGmmaWarpSpecialized { + static_assert( + cute::is_same_v< + KernelSchedule, + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< + ScaleGranularityM>>, + "KernelSchedule must be one of the warp specialized policies"); +}; + +////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_collective_builder.cuh b/csrc/cutlass_extensions/vllm_collective_builder.cuh index 085ee129..e7fbba4c 100644 --- a/csrc/cutlass_extensions/vllm_collective_builder.cuh +++ b/csrc/cutlass_extensions/vllm_collective_builder.cuh @@ -1,6 +1,6 @@ #pragma once -#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass_extensions/gemm/collective/collective_builder.hpp" namespace cutlass::gemm::collective { using namespace cute; diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh new file mode 100644 index 00000000..9ac7eee7 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh @@ -0,0 +1,93 @@ +#pragma once + +// clang-format will break include orders +// clang-format off +#include + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "core/math.hpp" +#include "cutlass_extensions/common.hpp" +// clang-format on + +namespace vllm::c3x { + +static inline cute::Shape get_problem_shape( + torch::Tensor const& a, torch::Tensor const& b) { + int32_t m = a.size(0), n = b.size(1), k = a.size(1); + return {m, n, k, 1}; +} + +template +void cutlass_gemm_caller(torch::Device device, + cute::Shape prob_shape, + typename GemmKernel::MainloopArguments mainloop_args, + typename GemmKernel::EpilogueArguments epilogue_args) { + typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, mainloop_args, epilogue_args}; + + // Launch the CUTLASS GEMM kernel. + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(device); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +template +void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... epilogue_params) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + using GemmKernel = typename Gemm::GemmKernel; + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = cute::Stride, int64_t>; + using StrideB = cute::Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, cute::Int<1>{}, 0}; + StrideB b_stride{ldb, cute::Int<1>{}, 0}; + StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}}; + + typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b); + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + std::forward(epilogue_params)...), + c_ptr, c_stride, c_ptr, c_stride}; + + cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +} // namespace vllm::c3x \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh similarity index 51% rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh rename to csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh index d4bc2f0a..9227ebb7 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh @@ -2,9 +2,6 @@ // clang-format will break include orders // clang-format off -#include - -#include #include "cutlass/cutlass.h" @@ -32,21 +29,6 @@ using namespace cute; namespace vllm { -// A wrapper for the GEMM kernel that is used to guard against compilation on -// architectures that will never use the kernel. The purpose of this is to -// reduce the size of the compiled binary. -// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef -// into code that will be executed on the device where it is defined. -template -struct enable_sm90_or_later : Kernel { - template - CUTLASS_DEVICE void operator()(Args&&... args) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 - Kernel::operator()(std::forward(args)...); -#endif - } -}; - template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, @@ -101,60 +83,4 @@ struct cutlass_3x_gemm { struct GemmKernel : public KernelType {}; }; -template -void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... epilogue_params) { - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - int32_t m = a.size(0); - int32_t n = b.size(1); - int32_t k = a.size(1); - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideA = Stride, int64_t>; - using StrideB = Stride, int64_t>; - using StrideC = typename Gemm::StrideC; - - StrideA a_stride{lda, Int<1>{}, 0}; - StrideB b_stride{ldb, Int<1>{}, 0}; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - using GemmKernel = typename Gemm::GemmKernel; - typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, - b_stride}; - - auto c_ptr = static_cast(out.data_ptr()); - typename GemmKernel::EpilogueArguments epilogue_args{ - Gemm::Epilogue::prepare_args( - std::forward(epilogue_params)...), - c_ptr, c_stride, c_ptr, c_stride}; - - typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, - prob_shape, mainloop_args, epilogue_args}; - - // Launch the CUTLASS GEMM kernel. - using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; - GemmOp gemm_op; - CUTLASS_CHECK(gemm_op.can_implement(args)); - - size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - - cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); - CUTLASS_CHECK(status); -} - } // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu new file mode 100644 index 00000000..4cd38f49 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu @@ -0,0 +1,24 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_sm90_int8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + std::optional const& azp, + std::optional const& bias) { + if (azp) { + return cutlass_scaled_mm_sm90_int8_epilogue< + c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj, + *azp, bias); + } else { + return cutlass_scaled_mm_sm90_int8_epilogue( + out, a, b, a_scales, b_scales, azp_adj, bias); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu new file mode 100644 index 00000000..0501e6da --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu @@ -0,0 +1,24 @@ + +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + if (out.dtype() == torch::kBFloat16) { + cutlass_gemm_blockwise_sm90_fp8_dispatch( + out, a, b, a_scales, b_scales); + + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + cutlass_gemm_blockwise_sm90_fp8_dispatch( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh new file mode 100644 index 00000000..fb7a82b8 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh @@ -0,0 +1,168 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass_extensions/gemm/dispatch_policy.hpp" +#include "cutlass_extensions/gemm/collective/collective_builder.hpp" + +#include "cutlass_gemm_caller.cuh" + +namespace vllm { + +using namespace cute; + +template > +struct cutlass_3x_gemm_fp8_blockwise { + using GroupSizeM = Int; + using GroupSizeN = Int; + using GroupSizeK = Int; + using TileSizeM = Int; + + static_assert(TileSizeM_ % GroupSizeM_ == 0, + "TileSizeM must be a multiple of GroupSizeM"); + + using ElementAB = cutlass::float_e4m3_t; + + using ElementA = ElementAB; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementAB; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementD = OutType; + using StrideD = Stride, Int<0>>; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using ElementC = void; + using StrideC = StrideD; + static constexpr int AlignmentC = AlignmentD; + + using ElementAccumulator = float; + using ElementBlockScale = float; + using ElementCompute = float; + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = Shape; + + using KernelSchedule = cutlass::gemm:: + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< + GroupSizeM_>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90AccFetch>; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC, + ElementD, StrideD, AlignmentD, EpilogueSchedule, + StoreEpilogueCompute>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, + LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>>; + + struct GemmKernel : public KernelType {}; + + using StrideA = typename GemmKernel::StrideA; + using StrideB = typename GemmKernel::StrideB; +}; + +template +void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + using GemmKernel = typename Gemm::GemmKernel; + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + auto prob_shape = c3x::get_problem_shape(a, b); + int32_t m = get<0>(prob_shape), n = get<1>(prob_shape), + k = get<2>(prob_shape); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); + + // Check is the t is contiguous and is 1D or 2D with one of the dimensions + // being 1 (i.e. a row or column vector) + auto is_contiguous_vector = [](const torch::Tensor& t) { + auto t_sizes = t.sizes(); + return t.is_contiguous() && + (t.dim() == 1 || + (t.dim() == 2 && + *std::min_element(t_sizes.begin(), t_sizes.end()) == 1)); + }; + + // TODO(lucas): lets clean-up the kernel so that we pass in Strides so + // we don't have to deal with enforcing implicit layouts + TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value); + TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value); + TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales), + "a_scales must be M major"); + TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value); + TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value); + TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales), + "b_scales must be K major"); + typename GemmKernel::MainloopArguments mainloop_args{ + a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +template +void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + cutlass_gemm_caller_blockwise< + cutlass_3x_gemm_fp8_blockwise>(out, a, b, a_scales, + b_scales); +} + +} // namespace vllm \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp new file mode 100644 index 00000000..7ede9e06 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include + +namespace vllm { + +void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias); + +void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias); + +void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + std::optional const& azp, + std::optional const& bias); + +void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu new file mode 100644 index 00000000..e092c61a --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu @@ -0,0 +1,24 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_sm90_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias) { + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (bias) { + TORCH_CHECK(bias->dtype() == out.dtype(), + "currently bias dtype must match output dtype ", out.dtype()); + return cutlass_scaled_mm_sm90_fp8_epilogue( + out, a, b, a_scales, b_scales, *bias); + } else { + return cutlass_scaled_mm_sm90_fp8_epilogue( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh similarity index 76% rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh rename to csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh index f08419b3..32ea5db3 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh @@ -1,6 +1,7 @@ #pragma once -#include "scaled_mm_c3x.cuh" +#include "scaled_mm.cuh" +#include "cutlass_gemm_caller.cuh" /** * This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm @@ -9,6 +10,8 @@ namespace vllm { +using c3x::cutlass_gemm_caller; + template typename Epilogue> struct sm90_fp8_config_default { @@ -93,4 +96,25 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, } } +template