vllm/csrc/quantization/machete/machete_mm_kernel.cuh

#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/vllm_numeric_conversion.cuh"
#include "machete_collective_builder.cuh"
#include "machete_prepacked_layout.cuh"
#include "machete_interleaving_utils.cuh"
namespace machete {
using namespace cute;
// NOTE This kernel computes D = alpha * A * B + beta * C by computing
// D^t = alpha * B^t * A^t + beta * C^t. This is because the wgmma
// instructions only support sourcing the left-hand operand from
// registers, and we want to upconvert/decompress the quantized operand
// in registers. Since the primary use case we want to support is
// Y = XW^t where W is quantized, it is the right-hand operand that is
// quantized, so we compute the transpose to move it to the left-hand
// side.
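// For example (illustrative shapes only): for M = 16, N = 4096, K = 4096
// we launch cutlass on the transposed problem (N, M, K) = (4096, 16, 4096);
// the prepacked quantized B^t (NxK) becomes the register-sourced left-hand
// operand and A^t (KxM) the right-hand one. Writing D^t (NxM) column-major
// is byte-identical to writing D (MxN) row-major, so no extra transpose of
// the output is needed.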
template <typename ElementA_, typename ElementB_, typename ElementD_,
typename AccumulatorT, typename ScaleT, typename ZeroT,
class KernelSchedule, typename ScheduleConfig, bool with_C,
bool with_scales, bool with_zeropoints>
struct MacheteKernelTemplate {
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementD = ElementD_;
using ElementC = cute::conditional_t<with_C, ElementD, void>;
using ElementZ = ZeroT;
using ElementS = ScaleT;
using ElementAccumulator =
AccumulatorT; // Element type for internal accumulation
using ElementCompute = AccumulatorT; // For Epilogue
using BTypeTuple = cute::conditional_t<
with_scales,
cute::conditional_t<with_zeropoints,
cute::tuple<ElementB, ElementS, ElementZ>,
cute::tuple<ElementB, ElementS>>,
ElementB>;
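// For example (hypothetical instantiation): with 4-bit weights, fp16
// group scales, and no zeropoints this resolves to
//   cute::tuple<cutlass::uint4b_t, cutlass::half_t>
// and collapses to plain ElementB for an unquantized mainloop.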
using LayoutA = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
using LayoutScale = cutlass::layout::RowMajor;
// not actually used since B has the prepacked layout, but required by cutlass
using _LayoutB = cutlass::layout::ColumnMajor;
// Interface strides expected by create_arguments (will get transposed)
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
using StrideS = cutlass::detail::TagToStrideA_t<LayoutScale>;
using StrideZ = StrideS;
using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutC_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using PrepackedLayoutB =
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT,
LayoutA_Transpose, KernelSchedule>;
static int constexpr TileShapeK =
128 * 8 / cutlass::sizeof_bits<MmaType>::value;
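// e.g. TileShapeK = 128 * 8 / 16 = 64 for a 16-bit MmaType such as
// cutlass::half_t, i.e. each K-tile spans 128 bytes per row.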
static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v<ElementA>;
static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v<ElementB>;
static int constexpr AlignmentC =
(with_C) ? 128 / cutlass::sizeof_bits_v<ElementC> : 0;
static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v<ElementD>;
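// e.g. with a 16-bit ElementA this gives AlignmentA = 128 / 16 = 8
// elements, i.e. 128-bit (16-byte) vectorized global loads.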
using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{},
cute::Int<TileShapeK>{}));
using ClusterShape = typename ScheduleConfig::ClusterShape;
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
using TileScheduler = typename ScheduleConfig::TileScheduler;
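// A ScheduleConfig is expected to provide members along these lines
// (hypothetical values, shown for illustration only):
//   TileShapeNM      = cute::Shape<cute::_128, cute::_64>
//   ClusterShape     = cute::Shape<cute::_1, cute::_1, cute::_1>
//   EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized
//   EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto
//   TileScheduler    = cutlass::gemm::PersistentScheduler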
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose,
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD,
EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::VLLMCollectiveBuilder<
cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose,
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// stride_B is unused (since B is prepacked), but still required by cutlass
using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>;
using Arguments = typename Gemm::Arguments;
using MainloopArguments = typename GemmKernel::MainloopArguments;
using EpilogueArguments = typename GemmKernel::EpilogueArguments;
template <typename ShapeA, typename ShapeC, typename ShapeD, typename ShapeS,
typename ShapeZ>
static Arguments create_arguments(
cudaStream_t stream,
ElementA const* A_ptr, // A is an MxK matrix
Layout<ShapeA, StrideA> const& layout_A,
ElementB const* B_ptr, // B is a KxN prepacked matrix
ElementD* D_ptr, // D is an MxN matrix
Layout<ShapeD, StrideD> const& layout_D,
ElementC const* C_ptr, // C is an MxN matrix
std::optional<Layout<ShapeC, StrideC>> const& layout_C,
ElementS const* S_ptr, // S is a scale_KxN matrix of scales
std::optional<Layout<ShapeS, StrideS>> const& layout_S,
ElementZ const* Z_ptr, // Z is a scale_KxN matrix of zeropoints
std::optional<Layout<ShapeZ, StrideZ>> const& layout_Z,
ElementCompute alpha, ElementCompute beta,
std::optional<int> maybe_group_size) {
static_assert(!with_zeropoints || with_scales);
int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
int const group_size =
maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
int const scale_k = (K + group_size - 1) / group_size;
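// e.g. K = 4096, group_size = 128 -> scale_k = 32; ceiling division so a
// ragged final group still gets its own scale row.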
TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
if constexpr (with_C) {
TORCH_CHECK(C_ptr && layout_C);
} else {
TORCH_CHECK(!C_ptr, "C not supported");
}
if constexpr (with_scales) {
TORCH_CHECK(S_ptr && layout_S);
TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N));
} else {
TORCH_CHECK(!S_ptr, "Scales not supported");
}
if constexpr (with_zeropoints) {
TORCH_CHECK(Z_ptr && layout_Z);
TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N));
TORCH_CHECK(layout_S && *layout_Z == *layout_S,
"Scales and zeros must have the same layout");
} else {
TORCH_CHECK(!Z_ptr, "Zeropoints not supported");
}
// Transpose A and D
// A doesn't need to be transposed since cutlass expects an NxK matrix
// for B (which is A^t)
auto stride_At = layout_A.stride();
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
auto stride_Ct = stride_Dt;
if (layout_C) {
stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride();
}
MainloopArguments mainloop_arguments{};
EpilogueArguments epilogue_arguments{
{alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt};
if constexpr (with_scales && with_zeropoints) {
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
S_ptr, stride_S, group_size, Z_ptr};
} else if constexpr (with_scales) {
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
mainloop_arguments = MainloopArguments{
B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size};
} else {
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
}
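// The problem shape is (N, M, K) rather than (M, N, K) since we run the
// transposed problem (see the NOTE at the top of this file).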
return Arguments{cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K, 1},
mainloop_arguments,
epilogue_arguments};
}
static size_t get_workspace_size(Arguments const& args) {
return Gemm::get_workspace_size(args);
}
static bool can_implement(Arguments const& args) {
return Gemm::can_implement(args) == cutlass::Status::kSuccess;
}
static void run(Arguments const& args, void* workspace, cudaStream_t stream) {
Gemm gemm_op;
cutlass::Status status = gemm_op.initialize(args, workspace, stream);
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Machete kernel failed to initialize");
status = gemm_op.run(stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed");
}
};
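// Example usage (a sketch, not an exact vllm call site: `Kernel` is assumed
// to be a concrete instantiation with with_scales = true and
// with_C = with_zeropoints = false, and the pointer/layout names are
// placeholders):
//
//   auto args = Kernel::create_arguments(
//       stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D,
//       /*C_ptr=*/nullptr, /*layout_C=*/std::nullopt,
//       S_ptr, layout_S, /*Z_ptr=*/nullptr, /*layout_Z=*/std::nullopt,
//       /*alpha=*/1.0f, /*beta=*/0.0f, /*maybe_group_size=*/128);
//   TORCH_CHECK(Kernel::can_implement(args), "unsupported problem");
//   auto workspace = torch::empty(
//       {static_cast<int64_t>(Kernel::get_workspace_size(args))},
//       torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA));
//   Kernel::run(args, workspace.data_ptr(), stream);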
} // namespace machete