#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on

#include "cutlass_extensions/cute_utils.cuh"
#include "machete_collective_builder.cuh"
#include "machete_interleaving_utils.cuh"

namespace machete {

using namespace cute;

struct IlvBlkLayoutAuto {};

// This defines a prepacked layout for the B matrix, where the matrix is broken
// up into PPBlockShape_NK blocks. The data within each block is then compactly
// stored in memory such that when performing a TiledMMA operation with the
// same shape as the prepacked block, all the data for a given thread is
// contiguous in memory. This allows us to use wider shared memory loads when
// loading B from shared memory. The values within a thread are also
// potentially interleaved in order to allow for more efficient upconverting.
//
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel (this is also why the other element types are
// required along with the kernel schedule).
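//
// Example instantiation (hypothetical parameter choices, for illustration
// only): fp16 activations with 4-bit weights upconverted to fp16 might use
//   PrepackedLayoutBTemplate<cutlass::half_t, cutlass::uint4b_t,
//                            cutlass::half_t, float,
//                            cutlass::layout::ColumnMajor,
//                            cutlass::gemm::KernelTmaWarpSpecializedCooperative>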
// clang-format off
template <typename ElementA_, typename ElementB_, typename ElementConvert_,
          typename AccumulatorT, class LayoutB, class KernelSchedule,
          typename IlvBlkLayout_ = IlvBlkLayoutAuto>
// clang-format on
struct PrepackedLayoutBTemplate {
  using MmaType = ElementA_;
  using ElementA = ElementA_;
  using ElementB = ElementB_;
  using ElementAccumulator = AccumulatorT;
  using ElementMma = MmaType;

  // Interleave for 4bit types when we are not upconverting to fp8 or int8; in
  // those cases we upconvert using a LUT built with prmt instructions, which
  // is more efficient when the data is not interleaved. For 8bit+ types, prmt
  // instructions make non-interleaved layouts efficient enough that we don't
  // need interleaved layouts (and can reuse more of the existing cutlass
  // converters).
  static constexpr bool should_interleave =
      sizeof_bits_v<ElementB> <= 4 &&
      !std::is_same_v<ElementConvert_, cutlass::float_e4m3_t> &&
      !std::is_same_v<ElementConvert_, int8_t>;
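  // For example (illustrative): uint4b_t weights upconverted to half_t take
  // the interleaved path, while uint4b_t weights upconverted to
  // float_e4m3_t or int8_t use the prmt-LUT path and stay non-interleaved.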

  // Only use interleaved layouts for subbyte weights.
  using IlvdBlkLayout = std::conditional_t<
      std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
      std::conditional_t<
          should_interleave,
          decltype(get_interleaved_blk_layout<
                   ElementB, sizeof_bits_v<ElementConvert_>, 32>()),
          void>,
      IlvBlkLayout_>;

  // TODO (LucasWilkinson): compare the performance for other sizes
  // Prepacked block shape, smallest layout atom for loading into registers
  //  (can contain multiple wgmma instructions worth of data in one block).
  // We ideally want this to be configured such that a thread can perform
  //  128bit loads, i.e. the amount of data associated with each thread within
  //  a prepacked block is a multiple of 128bits. When using a cooperative
  //  schedule we have 256 threads working on a single block at a time, so each
  //  thread works on `sizeof_bits_v<ElementB> * (128*64) / 256` bits of data;
  //  for a 4bit type this is 128bits.
  using PPBlockShape_NK = Shape<_128, _64>;
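
  // Illustrative constant (an addition, not referenced by the kernel): the
  // bits of B data owned by each of the 256 cooperative threads within one
  // prepacked block; for a 4bit ElementB this evaluates to 128, i.e. one
  // 128bit load per thread.
  static constexpr int kIllustrativeBitsPerThread =
      int(sizeof_bits_v<ElementB>) * int(size(PPBlockShape_NK{})) / 256;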

  // Create the shape of the tile anticipated to be used by the GEMM kernel;
  //  when the kernel executes we will compute `Ct = Bt * At`, since the
  //  quantized weights (B) must be the lhs operand so that they flow through
  //  registers.
  // The _128 here doesn't actually impact the shape of the stored tile
  //  directly, but may impact the op selected by rs_op_selector.
  using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
                                            size<1>(PPBlockShape_NK{})));

  static constexpr cute::GMMA::Major GmmaMajorB =
      gmma_rs_tag_to_major_B<LayoutB>();

  // For coop schedules we have two warp groups cooperatively issuing wgmma
  // instructions, so we use 2 atoms along the M dim (one for each warp group).
  using AtomLayoutMNK = cute::conditional_t<
      cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
      Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
                                 GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
      AtomLayoutMNK{}));

  // Prepacked block, (athrid, val) -> (N,K)
  // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
  CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
    return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
  }

  // Prepacked block, (N,K) -> (athrid, val)
  // i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
  CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
    return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
  }
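
  // Note (illustrative): the two maps above are inverses on a prepacked
  // block, e.g. ppblock_TV_to_NK().compose(ppblock_NK_to_TV()) first maps
  // (N,K) to (athrid, val) and then back, acting as the identity on (N,K).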

  // Prepacked block, (athrid, val) -> (storage_offset)
  // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
  CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
    // Return interleaved layout
    return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
  }

  // Prepacked block, (athrid, val) -> (storage_offset)
  // i.e. ((ThrV,(ThrM,ThrK)),(IlvdFrgV,(RestM,RestK,...))) -> (storage_idx)
  CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() {
    auto layout_no_interleave =
        make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});

    if constexpr (std::is_same_v<IlvdBlkLayout, void>) {
      return layout_no_interleave;
    } else {
      // Interleave by transforming FrgV into interleaved blocks where each
      // block has the layout IlvdBlkLayout. For example, if IlvdBlkLayout is
      // (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4);
      // if FrgV is {A, B, C, D, E, F, G, H}
      // then ((IlvBlk), FrgV) is {A, C, B, D, E, G, F, H}
      auto frgV = get<1, 0>(layout_no_interleave);
      auto ilvdBlk = IlvdBlkLayout{};

      static_assert(size(frgV) % size(ilvdBlk) == 0,
                    "size(FrgV) must be divisible by size(ilvdBlk)");

      auto ilvd_FrgV = make_layout(
          make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
          make_stride(stride(ilvdBlk), size(ilvdBlk)));

      // Return interleaved layout
      return make_layout(
          get<0>(layout_no_interleave),
          make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave)));
    }
  }

  // Prepacked block, (N,K) -> (storage_offset)
  CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() {
    // do (N,K) -> (athrid, val) -> (storage_idx)
    return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV());
  }

  // ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset(
      Shape_NKL shape_mkl) {
    constexpr auto block_layout = ppblock_TV_to_offset();

    // (BlocksN, BlocksK, L)
    auto blocks_shape =
        cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
                        [](auto x, auto y) { return x / y; });

    // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
    auto result = make_layout(
        block_layout,
        make_layout(blocks_shape,
                    compact_col_major(blocks_shape, size(block_layout))));

    // ((athrid, val), (BlocksN, BlocksK, L))
    //   => ((athrid, val), (BlocksN, BlocksK), L)
    return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
  }
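
  // Worked example (illustrative): for shape_mkl = (256, 128, 1) with
  // PPBlockShape_NK = (128, 64), blocks_shape is (2, 2, 1); each block
  // occupies size(block_layout) = 128 * 64 contiguous storage elements, and
  // blocks are placed one after another in column-major block order.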

  // ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy(
      Shape_NKL shape_mkl) {
    auto layout = TVbNbKL_to_offset(shape_mkl);
    return make_layout(coalesce(get<0>(layout)), get<1>(layout),
                       get<2>(layout));
  }

  // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
      Shape_NKL shape_mkl) {
    constexpr auto block_layout = ppblock_ilvd_NK_to_offset();

    // (BlocksN, BlocksK, L)
    auto blocks_shape =
        cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
                        [](auto x, auto y) { return x / y; });

    // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (storage_idx)
    auto result = make_layout(
        block_layout,
        make_layout(blocks_shape,
                    compact_col_major(blocks_shape, size(block_layout))));

    // ((BlockN, BlockK), (BlocksN, BlocksK, L))
    //   => ((BlockN, BlockK), (BlocksN, BlocksK), L)
    return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
  }

  // (BlocksN, BlocksK, L) -> (storage_idx)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) {
    // (BlocksN, BlocksK, L)
    auto blocks_shape =
        cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
                        [](auto x, auto y) { return x / y; });
    auto stride = size(PPBlockShape_NK{});

    // (BlocksN, BlocksK, L) -> (storage_idx)
    return make_layout(blocks_shape, compact_col_major(blocks_shape, stride));
  }
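
  // For example (illustrative): each (BlockN, BlockK) block spans
  // size(PPBlockShape_NK{}) = 128 * 64 = 8192 elements of storage, so
  // consecutive blocks start 8192 elements apart in column-major block order.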

  // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
  template <class Shape_NKL>
  CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
    auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})),
                          make_layout(size<1>(PPBlockShape_NK{})));

    // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L)
    auto tiled_A = zipped_divide(make_layout(shape_mkl), tile);
    return tiled_A.compose(ppblock_TV_to_NK(), _);
  }
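
  // Note (illustrative): zipped_divide splits (N, K, L) into
  // ((BlockN, BlockK), (BlocksN, BlocksK, L)); composing the first mode with
  // ppblock_TV_to_NK() re-indexes each block by (athrid, val), yielding
  // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L).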

  // (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
  template <class Shape_NKL>
  CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) {
    auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl);
    return blocked_product(ppblock_NK_to_TV(),
                           make_layout(shape<1>(TVbNbK_to_NKL_layout)));
  }
};

}  // namespace machete