[Kernel] Add w8a8 CUTLASS kernels (#4749)
This commit is contained in:
parent
8435b207af
commit
2060e93659
@ -173,6 +173,16 @@ set(VLLM_EXT_SRC
|
||||
"csrc/pybind.cpp")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
include(FetchContent)
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY=ON)
|
||||
FetchContent_Declare(
|
||||
cutlass
|
||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||
# CUTLASS 3.5.0
|
||||
GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
|
||||
)
|
||||
FetchContent_MakeAvailable(cutlass)
|
||||
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/quantization/aqlm/gemm_kernels.cu"
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
@ -180,7 +190,21 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/custom_all_reduce.cu")
|
||||
"csrc/custom_all_reduce.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")
|
||||
|
||||
#
|
||||
# The CUTLASS kernels for Hopper require sm90a to be enabled.
|
||||
# This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
|
||||
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
|
||||
set_source_files_properties(
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
|
||||
PROPERTIES
|
||||
COMPILE_FLAGS
|
||||
"-gencode arch=compute_90a,code=sm_90a")
|
||||
|
||||
endif()
|
||||
|
||||
define_gpu_extension_target(
|
||||
@ -190,6 +214,7 @@ define_gpu_extension_target(
|
||||
SOURCES ${VLLM_EXT_SRC}
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||
WITH_SOABI)
|
||||
|
||||
#
|
||||
|
@ -155,6 +155,14 @@ torch::Tensor gptq_marlin_repack(
|
||||
int64_t size_k,
|
||||
int64_t size_n,
|
||||
int64_t num_bits);
|
||||
|
||||
int cutlass_scaled_mm_dq(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales);
|
||||
|
||||
#endif
|
||||
|
||||
void squeezellm_gemm(
|
||||
|
@ -71,6 +71,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
|
||||
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
|
||||
ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization.");
|
||||
#endif
|
||||
|
||||
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
|
||||
|
12
csrc/quantization/cutlass_w8a8/common.hpp
Normal file
12
csrc/quantization/cutlass_w8a8/common.hpp
Normal file
@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/**
|
||||
* Helper function for checking CUTLASS errors
|
||||
*/
|
||||
#define CUTLASS_CHECK(status) \
|
||||
{ \
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, \
|
||||
cutlassGetStatusString(status)) \
|
||||
}
|
@ -0,0 +1,340 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
//
|
||||
// This file is a modified excerpt of
|
||||
// include/cutlass/epilogue/fusion/visitor_load.hpp from
|
||||
// https://github.com/NVIDIA/cutlass It's beem modified to support either
|
||||
// row/column or scalar broadcasting, like is already supported in CUTLASS 3.x.
|
||||
// Important because this saves us a factor 4x on the number of kernels
|
||||
// compiled.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
// clang-format on
|
||||
|
||||
namespace cutlass::epilogue::threadblock {
|
||||
|
||||
using namespace cute;
|
||||
using namespace detail;
|
||||
|
||||
template<
|
||||
class ThreadMap,
|
||||
class Element,
|
||||
class StrideMNL
|
||||
>
|
||||
struct VisitorRowOrScalarBroadcast {
|
||||
|
||||
struct Arguments {
|
||||
Element const* ptr_row = nullptr;
|
||||
Element null_default = Element(0);
|
||||
StrideMNL dRow = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct SharedStorage {};
|
||||
|
||||
// Global load type
|
||||
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
|
||||
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
|
||||
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorRowOrScalarBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params_ptr(¶ms) { }
|
||||
|
||||
Params const* params_ptr;
|
||||
|
||||
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct Callbacks : EmptyCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
Callbacks(
|
||||
GTensor&& tC_gRow,
|
||||
RTensor&& tC_rRow,
|
||||
CTensor&& tC_cRow,
|
||||
ProblemShape problem_shape,
|
||||
Params const* params_ptr
|
||||
):
|
||||
tC_gRow(cute::forward<GTensor>(tC_gRow)),
|
||||
tC_rRow(cute::forward<RTensor>(tC_rRow)),
|
||||
tC_cRow(cute::forward<CTensor>(tC_cRow)),
|
||||
n(get<1>(problem_shape)),
|
||||
params_ptr(params_ptr) { }
|
||||
|
||||
GTensor tC_gRow;
|
||||
RTensor tC_rRow;
|
||||
CTensor tC_cRow;
|
||||
Params const* params_ptr;
|
||||
int n;
|
||||
|
||||
// This function is modified from VisitorRowBroadcast
|
||||
CUTLASS_DEVICE void
|
||||
begin_epilogue() {
|
||||
clear(tC_rRow);
|
||||
auto src_v = filter(tC_gRow);
|
||||
auto coord_v = filter(tC_cRow);
|
||||
auto dst_v = filter(tC_rRow);
|
||||
|
||||
if (params_ptr->ptr_row) {
|
||||
// In this case we are loading from a row vector and broadcasting
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src_v); ++i) {
|
||||
bool guard = get<1>(coord_v(i)) < n;
|
||||
cutlass::arch::global_load<VecType, sizeof(VecType)>(dst_v(i), (void const*)&src_v(i), guard);
|
||||
}
|
||||
} else {
|
||||
// In this case we are loading from a scalar and broadcasting
|
||||
VecType filled_vec;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < VecLength; i++) {
|
||||
reinterpret_cast<Element*>(&filled_vec)[i] = params_ptr->null_default;
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src_v); ++i) {
|
||||
if(get<1>(coord_v(i)) < n)
|
||||
{
|
||||
dst_v(i) = filled_vec;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE auto // returns an Array
|
||||
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
|
||||
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
|
||||
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
|
||||
return rRow_frg(column_idx);
|
||||
}
|
||||
};
|
||||
|
||||
template <class ProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
get_callbacks(
|
||||
gemm::GemmCoord threadblock_tile_offset,
|
||||
int thread_idx,
|
||||
ProblemShape problem_shape
|
||||
) {
|
||||
Tensor mRow = make_tensor(
|
||||
make_gmem_ptr(params_ptr->ptr_row),
|
||||
problem_shape,
|
||||
params_ptr->dRow);
|
||||
|
||||
// VECTOR, FRAGMENT_COLUMN
|
||||
Tensor tC_gRow = recast<VecType>(
|
||||
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
|
||||
)(_,_,_0{},_0{},_0{},_0{});
|
||||
Tensor tC_rRow = make_tensor_like(tC_gRow);
|
||||
|
||||
// Generate the pred tensor
|
||||
Tensor cRow = make_identity_tensor(mRow.shape());
|
||||
Tensor tC_cRow = outer_partition(
|
||||
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
|
||||
Shape<Int<VecLength>>{},
|
||||
(_0{})
|
||||
);
|
||||
|
||||
return Callbacks<
|
||||
decltype(tC_gRow), decltype(tC_rRow),
|
||||
decltype(tC_cRow), ProblemShape>(
|
||||
cute::move(tC_gRow),
|
||||
cute::move(tC_rRow),
|
||||
cute::move(tC_cRow),
|
||||
problem_shape,
|
||||
params_ptr
|
||||
);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Column vector broadcast
|
||||
template<
|
||||
class ThreadMap,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_1,_0,_0>
|
||||
>
|
||||
struct VisitorColOrScalarBroadcast {
|
||||
|
||||
struct Arguments {
|
||||
Element const* ptr_col = nullptr;
|
||||
Element null_default = Element(0);
|
||||
StrideMNL dCol = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct SharedStorage { };
|
||||
|
||||
// Global load type
|
||||
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
|
||||
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
|
||||
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorColOrScalarBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params_ptr(¶ms) { }
|
||||
|
||||
Params const* params_ptr;
|
||||
|
||||
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct Callbacks : EmptyCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
Callbacks(
|
||||
GTensor&& tC_gCol,
|
||||
RTensor&& tC_rCol,
|
||||
CTensor&& tC_cCol,
|
||||
ProblemShape problem_shape,
|
||||
Params const* params_ptr
|
||||
):
|
||||
tC_gCol(cute::forward<GTensor>(tC_gCol)),
|
||||
tC_rCol(cute::forward<RTensor>(tC_rCol)),
|
||||
tC_cCol(cute::forward<CTensor>(tC_cCol)),
|
||||
m(get<0>(problem_shape)),
|
||||
params_ptr(params_ptr) { }
|
||||
|
||||
GTensor tC_gCol;
|
||||
RTensor tC_rCol;
|
||||
CTensor tC_cCol;
|
||||
Params const* params_ptr;
|
||||
int m;
|
||||
|
||||
// This function is modified from VisitorColBroadcast
|
||||
CUTLASS_DEVICE void
|
||||
begin_epilogue() {
|
||||
clear(tC_rCol);
|
||||
|
||||
Tensor pred = make_tensor<bool>(shape(tC_gCol));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred); ++i) {
|
||||
pred(i) = get<0>(tC_cCol(i)) < m;
|
||||
}
|
||||
|
||||
if (params_ptr->ptr_col) {
|
||||
// In this case we are loading from a column vector and broadcasting
|
||||
copy_if(pred, tC_gCol, tC_rCol);
|
||||
} else {
|
||||
// In this case we are loading from a scalar and broadcasting
|
||||
auto dst_v = filter(tC_rCol);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(dst_v); ++i) {
|
||||
if(pred(i)){
|
||||
dst_v(i) = params_ptr->null_default;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE auto // returns an Array
|
||||
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
|
||||
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
|
||||
Array<Element, FragmentSize> frg_col;
|
||||
frg_col.fill(tC_rCol(row_idx,iter_idx));
|
||||
return frg_col;
|
||||
}
|
||||
};
|
||||
|
||||
template <class ProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
get_callbacks(
|
||||
gemm::GemmCoord threadblock_tile_offset,
|
||||
int thread_idx,
|
||||
ProblemShape problem_shape
|
||||
) {
|
||||
Tensor mCol = make_tensor(
|
||||
make_gmem_ptr(params_ptr->ptr_col),
|
||||
problem_shape,
|
||||
params_ptr->dCol);
|
||||
|
||||
// VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
|
||||
Tensor tC_gCol = group_modes<1,4>(
|
||||
ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
|
||||
Tensor tC_rCol = make_tensor_like(tC_gCol);
|
||||
|
||||
// Generate the pred tensor
|
||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
Tensor tC_cCol = group_modes<1,4>(
|
||||
ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
|
||||
|
||||
return Callbacks<
|
||||
decltype(tC_gCol), decltype(tC_rCol),
|
||||
decltype(tC_cCol), ProblemShape>(
|
||||
cute::move(tC_gCol),
|
||||
cute::move(tC_rCol),
|
||||
cute::move(tC_cCol),
|
||||
problem_shape,
|
||||
params_ptr
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
296
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
Normal file
296
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
Normal file
@ -0,0 +1,296 @@
|
||||
#include <stddef.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm_coord.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
||||
|
||||
#include "cutlass_visitor_2x_broadcast_epilogue.hpp"
|
||||
#include "common.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/*
|
||||
This defines a quantized GEMM operation with dequantized output, similar to
|
||||
torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
|
||||
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
||||
|
||||
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
|
||||
per-row. B can be quantized per-tensor or per-column.
|
||||
Any combination of per-tensor and per-row or column is supported.
|
||||
A and B must have symmetric quantization (zero point == 0).
|
||||
|
||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||
scales are applied elementwise with numpy-style broadcasting.
|
||||
|
||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
*/
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename Arch, typename ElementAB_, typename ElementD_,
|
||||
typename TileShape, typename WarpShape, typename InstructionShape,
|
||||
int32_t MainLoopStages>
|
||||
struct cutlass_2x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using Operator =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::arch::OpMultiplyAdd>::type;
|
||||
|
||||
using OutputTileThreadMap =
|
||||
cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
||||
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
|
||||
>;
|
||||
|
||||
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||
|
||||
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
|
||||
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute1 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
|
||||
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
|
||||
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
|
||||
Stride<int64_t, Int<1>, Int<0>>>;
|
||||
|
||||
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>;
|
||||
|
||||
// clang-format off
|
||||
using RowMajor = typename cutlass::layout::RowMajor;
|
||||
using ColumnMajor = typename cutlass::layout::ColumnMajor;
|
||||
using KernelType =
|
||||
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
||||
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
|
||||
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
|
||||
float, cutlass::layout::RowMajor, 4,
|
||||
ElementAcc, float, cutlass::arch::OpClassTensorOp,
|
||||
Arch,
|
||||
TileShape, WarpShape, InstructionShape,
|
||||
EVTD,
|
||||
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
|
||||
MainLoopStages, Operator,
|
||||
1 /* epilogue stages */
|
||||
>::GemmKernel;
|
||||
// clang-format on
|
||||
|
||||
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int32_t m = a.size(0);
|
||||
int32_t n = b.size(1);
|
||||
int32_t k = a.size(1);
|
||||
cutlass::gemm::GemmCoord problem_size{m, n, k};
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
auto a_ptr = static_cast<ElementAB const *>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB const *>(b.data_ptr());
|
||||
auto c_ptr = static_cast<ElementD *>(out.data_ptr());
|
||||
|
||||
auto a_scales_ptr = a_scales.data_ptr<float>();
|
||||
auto b_scales_ptr = b_scales.data_ptr<float>();
|
||||
|
||||
// If A and B are quantized per-tensor, then these scale tensors are scalars,
|
||||
// and they are passed in via the second argument.
|
||||
using ScaleAArgs = typename Gemm::ScaleA::Arguments;
|
||||
ScaleAArgs a_args = a_scales.numel() == 1
|
||||
? ScaleAArgs{nullptr, a_scales.item<float>(), {}}
|
||||
: ScaleAArgs{a_scales.data_ptr<float>(), {}, {}};
|
||||
|
||||
using ScaleBArgs = typename Gemm::ScaleB::Arguments;
|
||||
ScaleBArgs b_args = b_scales.numel() == 1
|
||||
? ScaleBArgs{nullptr, b_scales.item<float>(), {}}
|
||||
: ScaleBArgs{b_scales.data_ptr<float>(), {}, {}};
|
||||
|
||||
typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};
|
||||
|
||||
typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
|
||||
evt0_compute_args};
|
||||
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
|
||||
|
||||
typename Gemm::EVTD::Arguments epilogue_args{
|
||||
evt1_compute_args,
|
||||
d_args,
|
||||
};
|
||||
|
||||
typename Gemm::Op::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode
|
||||
problem_size, // problem size
|
||||
1, // batch count
|
||||
epilogue_args,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldc};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
typename Gemm::Op gemm_op;
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
cutlass::Status status = gemm_op(args, workspace.get());
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::bfloat16_t,
|
||||
TileShape, WarpShape, InstructionShape, 2>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::half_t, TileShape,
|
||||
WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::bfloat16_t,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::half_t, TileShape,
|
||||
WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (a.dtype() == torch::kInt8) {
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::bfloat16_t,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
assert(out.dtype() == torch::kFloat16);
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::half_t,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::bfloat16_t,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::half_t,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
}
|
240
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
Normal file
240
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
Normal file
@ -0,0 +1,240 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "common.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/*
|
||||
This defines a quantized GEMM operation with dequantized output, similar to
|
||||
torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for
|
||||
NVIDIA GPUs with sm90a (Hopper) or later.
|
||||
|
||||
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
|
||||
per-row. B can be quantized per-tensor or per-column.
|
||||
Any combination of per-tensor and per-row or column is supported.
|
||||
A and B must have symmetric quantization (zero point == 0).
|
||||
|
||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||
scales are applied elementwise with numpy-style broadcasting.
|
||||
|
||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
*/
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename ElementAB_, typename ElementD_, typename TileShape,
|
||||
typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using EpilogueDescriptor =
|
||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
||||
ElementD, EpilogueSchedule>;
|
||||
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
|
||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
using ScaleBDescriptor =
|
||||
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
|
||||
EpilogueDescriptor, float>;
|
||||
|
||||
using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
|
||||
typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute1 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
|
||||
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using ElementC = void;
|
||||
using StrideC = StrideD;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
|
||||
EpilogueSchedule, EVTCompute1>::CollectiveOp;
|
||||
|
||||
static constexpr size_t CEStorageSize =
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(CEStorageSize)>;
|
||||
|
||||
// clang-format off
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementAB, cutlass::layout::RowMajor, 16,
|
||||
ElementAB, cutlass::layout::ColumnMajor, 16,
|
||||
ElementAcc, TileShape, ClusterShape,
|
||||
Stages,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
// clang-format on
|
||||
|
||||
using KernelType = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
cutlass::gemm::PersistentScheduler>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int32_t m = a.size(0);
|
||||
int32_t n = b.size(1);
|
||||
int32_t k = a.size(1);
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using StrideB = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using StrideC = typename Gemm::StrideC;
|
||||
|
||||
StrideA a_stride{lda, Int<1>{}, Int<0>{}};
|
||||
StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
|
||||
|
||||
auto a_ptr = static_cast<ElementAB *>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB *>(b.data_ptr());
|
||||
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
|
||||
b_stride};
|
||||
|
||||
auto c_ptr = static_cast<ElementD *>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape, mainloop_args, epilogue_args};
|
||||
|
||||
using ScaleA_Args = typename Gemm::ScaleA::Arguments;
|
||||
using ScaleB_Args = typename Gemm::ScaleB::Arguments;
|
||||
ScaleA_Args a_args = a_scales.numel() == 1
|
||||
? ScaleA_Args{nullptr, a_scales.item<float>(), {}}
|
||||
: ScaleA_Args{a_scales.data_ptr<float>(), {}, {}};
|
||||
|
||||
ScaleB_Args b_args = b_scales.numel() == 1
|
||||
? ScaleB_Args{nullptr, b_scales.item<float>(), {}}
|
||||
: ScaleB_Args{b_scales.data_ptr<float>(), {}, {}};
|
||||
|
||||
args.epilogue.thread = {a_args, {b_args}};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
TORCH_CHECK(workspace_size == 0);
|
||||
|
||||
cutlass::Status status = gemm_op.run(args);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (a.dtype() == torch::kInt8) {
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using KernelSchedule =
|
||||
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_3x_gemm<int8_t, cutlass::bfloat16_t, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_3x_gemm<int8_t, cutlass::half_t, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using KernelSchedule =
|
||||
typename cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative;
|
||||
using EpilogueSchedule =
|
||||
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_3x_gemm<cutlass::float_e4m3_t, cutlass::bfloat16_t, TileShape,
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_3x_gemm<cutlass::float_e4m3_t, cutlass::half_t, TileShape,
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
}
|
65
csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
Normal file
65
csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
Normal file
@ -0,0 +1,65 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
void cutlass_scaled_mm_dq_sm75(torch::Tensor &c, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales);
|
||||
|
||||
void cutlass_scaled_mm_dq_sm80(torch::Tensor &c, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales);
|
||||
|
||||
void cutlass_scaled_mm_dq_sm89(torch::Tensor &c, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales);
|
||||
|
||||
void cutlass_scaled_mm_dq_sm90(torch::Tensor &c, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales);
|
||||
|
||||
void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a,
|
||||
torch::Tensor const &b, torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales) {
|
||||
int32_t major_capability;
|
||||
int32_t minor_capability;
|
||||
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
|
||||
0);
|
||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||
0);
|
||||
int32_t version_num = major_capability * 10 + minor_capability;
|
||||
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
|
||||
if (version_num >= 90) {
|
||||
// Hopper
|
||||
cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
|
||||
} else if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);
|
||||
} else if (version_num >= 80) {
|
||||
// Ampere
|
||||
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
// Turing
|
||||
TORCH_CHECK(version_num >= 75);
|
||||
cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
192
tests/kernels/test_cutlass.py
Normal file
192
tests/kernels/test_cutlass.py
Normal file
@ -0,0 +1,192 @@
|
||||
"""Tests for cutlass kernels
|
||||
|
||||
Run `pytest tests/kernels/test_cutlass.py`.
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.tensor):
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(
|
||||
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.tensor):
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
def cutlass_fp8_gemm_helper(m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
per_token_act_quant: bool,
|
||||
per_out_channel_weight_quant: bool,
|
||||
out_dtype: Type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda"):
|
||||
# Test for a cutlass kernel with per-token activation quantization
|
||||
# and per-output channel weight quantization.
|
||||
a = to_fp8(torch.randn((m, k), device=device))
|
||||
b = to_fp8(torch.randn((n, k), device=device).t())
|
||||
|
||||
m_a_scales = m if per_token_act_quant else 1
|
||||
n_b_scales = n if per_out_channel_weight_quant else 1
|
||||
|
||||
scale_a = (torch.randn(
|
||||
(m_a_scales, 1), device=device, dtype=torch.float32) / 10)
|
||||
scale_b = (torch.randn(
|
||||
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
|
||||
|
||||
out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype)
|
||||
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b * b.to(dtype=torch.float32)).to(out_dtype)
|
||||
|
||||
assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1)
|
||||
|
||||
|
||||
def cutlass_int8_gemm_helper(m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
per_token_act_quant: bool,
|
||||
per_out_channel_weight_quant: bool,
|
||||
out_dtype: Type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda"):
|
||||
# Test for a cutlass kernel with per-token activation quantization
|
||||
# and per-output channel weight quantization.
|
||||
a = to_int8(torch.randn((m, k), device=device) * 5)
|
||||
b = to_int8(torch.randn((n, k), device=device).t() * 5)
|
||||
|
||||
m_a_scales = m if per_token_act_quant else 1
|
||||
n_b_scales = n if per_out_channel_weight_quant else 1
|
||||
|
||||
scale_a = (torch.randn(
|
||||
(m_a_scales, 1), device=device, dtype=torch.float32) / 10)
|
||||
scale_b = (torch.randn(
|
||||
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
|
||||
|
||||
out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype)
|
||||
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b *
|
||||
b.to(dtype=torch.float32)).to(dtype=out_dtype)
|
||||
|
||||
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [512, 222, 33, 1])
|
||||
@pytest.mark.parametrize("n", [2048, 256, 1024])
|
||||
@pytest.mark.parametrize("k", [128, 496, 1024])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.skipif(capability < 89,
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
|
||||
per_out_ch: bool):
|
||||
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [512, 222, 33, 1])
|
||||
@pytest.mark.parametrize("n", [2048, 256, 1024])
|
||||
@pytest.mark.parametrize("k", [128, 496, 1024])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
|
||||
per_out_ch: bool):
|
||||
cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
|
||||
out_dtype: Type[torch.dtype]):
|
||||
cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
|
||||
out_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.skipif(capability < 89,
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
|
||||
out_dtype: Type[torch.dtype]):
|
||||
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
|
||||
out_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(capability < 89,
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
|
||||
device: str):
|
||||
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
|
||||
torch.bfloat16, device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
|
||||
device: str):
|
||||
cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
|
||||
torch.bfloat16, device)
|
||||
|
||||
|
||||
# For the following two tests:
|
||||
# N and K correspond to the size of the weight matrix and likely to be multiples
|
||||
# of a large power of two. In any case, the kernel will have a naive fallback
|
||||
# when N and K are not divisible by 16. But M is the number of tokens and the
|
||||
# kernel must handle any M thrown at it.
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.skipif(capability < 89,
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool):
|
||||
for nk in range(32, 128, 32):
|
||||
for m in range(1, 128):
|
||||
cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool):
|
||||
for nk in range(32, 128, 32):
|
||||
for m in range(1, 128):
|
||||
cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch)
|
||||
|
||||
|
||||
# Test working with a subset of A and B
|
||||
def test_cutlass_subset():
|
||||
big_m, big_n, big_k = 1024, 1024, 1024
|
||||
m, n, k = 512, 512, 512
|
||||
|
||||
whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5)
|
||||
whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5)
|
||||
a = whole_a[0:m, 0:k]
|
||||
b = whole_b[0:k, 0:n]
|
||||
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
out = ops.cutlass_scaled_mm_dq(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b *
|
||||
b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
|
||||
|
||||
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
|
@ -1,4 +1,4 @@
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
@ -163,6 +163,22 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
size_k)
|
||||
|
||||
|
||||
# cutlass
|
||||
def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
|
||||
a_scales: torch.Tensor, b_scales: torch.Tensor,
|
||||
out_dtype: Type[torch.dtype]) -> torch.Tensor:
|
||||
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
|
||||
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
|
||||
|
||||
m = a.shape[0]
|
||||
n = b.shape[1]
|
||||
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
||||
|
||||
vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
# aqlm
|
||||
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
|
||||
codebooks: torch.Tensor, scales: torch.Tensor,
|
||||
|
Loading…
x
Reference in New Issue
Block a user