[Kernel] Adding bias epilogue support for cutlass_scaled_mm (#5560)

Co-authored-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Luka Govedič 2024-06-26 11:16:00 -04:00 committed by GitHub
parent 6984c02a27
commit 5bfd1bbc98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 385 additions and 136 deletions

View File

@ -2,7 +2,8 @@ cmake_minimum_required(VERSION 3.21)
project(vllm_extensions LANGUAGES CXX)
option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda")
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")

View File

@ -96,7 +96,8 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
#endif

View File

@ -77,24 +77,12 @@ struct enable_sm89_to_sm90 : Kernel {
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
* This class provides the common ScaleA and ScaleB descriptors for the
* ScaledEpilogue and ScaledEpilogueBias classes.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue {
private:
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
@ -102,6 +90,32 @@ struct ScaledEpilogue {
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
@ -134,6 +148,53 @@ struct ScaledEpilogue {
}
};
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
public:
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
using ScaleAArgs = typename ScaleA::Arguments;
using ScaleBArgs = typename ScaleB::Arguments;
using BiasArgs = typename Bias::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
typename EVTCompute0::Arguments evt0_compute_args{b_args};
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
bias_args};
return evt_compute_args;
}
};
template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape,
@ -168,13 +229,13 @@ struct cutlass_2x_gemm {
// clang-format off
using RowMajor = typename cutlass::layout::RowMajor;
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using KernelType =
using KernelType =
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
float, cutlass::layout::RowMajor, 4,
ElementAcc, float, cutlass::arch::OpClassTensorOp,
Arch,
ElementAcc, float, cutlass::arch::OpClassTensorOp,
Arch,
TileShape, WarpShape, InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
@ -404,14 +465,13 @@ void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a,
}
}
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
@ -420,78 +480,130 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, a_scales, b_scales);
Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, a_scales, b_scales);
Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogue>(out, a, b, a_scales,
b_scales);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
ScaledEpilogue>(out, a, b, a_scales,
b_scales);
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, ScaledEpilogue>(
out, a, b, a_scales, b_scales);
return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogue>(out, a, b, a_scales,
b_scales);
}
}
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
assert(out.dtype() == torch::kFloat16);
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
return cutlass_gemm_caller<
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue,
TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
return cutlass_gemm_caller<
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
cutlass::float_e4m3_t, cutlass::half_t, Epilogue,
TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
}
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogue>(out, a, b, a_scales,
b_scales);
}
}

View File

@ -59,6 +59,28 @@ struct enable_sm90_or_later : Kernel {
}
};
/*
* This class provides the common ScaleA and ScaleB descriptors for the
* ScaledEpilogue and ScaledEpilogueBias classes.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
Stride<Int<1>, Int<0>, Int<0>>>;
using ScaleBDescriptor =
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
EpilogueDescriptor, float>;
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch.scaled_mm_.
@ -76,21 +98,13 @@ struct enable_sm90_or_later : Kernel {
per row or column.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue {
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
Stride<Int<1>, Int<0>, Int<0>>>;
using ScaleBDescriptor =
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
EpilogueDescriptor, float>;
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
@ -120,6 +134,54 @@ struct ScaledEpilogue {
}
};
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, ElementD, ElementD,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, ElementD,
cutlass::FloatRoundStyle::round_to_nearest>;
using BiasDescriptor =
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
EpilogueDescriptor, ElementD>;
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
BiasDescriptor::Stages, typename EpilogueDescriptor::TileShape, ElementD,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
using ScaleA_Args = typename ScaleA::Arguments;
using ScaleB_Args = typename ScaleB::Arguments;
using Bias_Args = typename Bias::Arguments;
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
Bias_Args bias_args{static_cast<ElementD*>(bias.data_ptr())};
return ArgumentType{a_args, {b_args}, bias_args};
}
};
template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
@ -440,41 +502,56 @@ void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
}
}
void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
ScaledEpilogue>(
out, a, b, a_scales, b_scales);
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t,
ScaledEpilogue>(
out, a, b, a_scales, b_scales);
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_fp8_dispatch<
cutlass::float_e4m3_t, cutlass::bfloat16_t, ScaledEpilogue>(
out, a, b, a_scales, b_scales);
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, ScaledEpilogue>(
out, a, b, a_scales, b_scales);
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
}
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
TORCH_CHECK(bias->dtype() == c.dtype(),
"currently bias dtype must match output dtype ", c.dtype());
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBias>(
c, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogue>(c, a, b, a_scales,
b_scales);
}
}
#endif

View File

@ -6,23 +6,27 @@
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
#endif
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
@ -43,7 +47,8 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
int32_t major_capability;
int32_t minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
@ -66,6 +71,11 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
bias->dim() == 1);
}
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
if (version_num >= 90) {
@ -73,19 +83,19 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
#else
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
#endif
} else if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
} else if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
} else {
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
}
}
}

View File

@ -142,7 +142,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales) -> ()");
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
// Check if cutlass scaled_mm is supported for CUDA devices of the given

View File

@ -32,6 +32,7 @@ def cutlass_fp8_gemm_helper(m: int,
k: int,
per_token_act_quant: bool,
per_out_channel_weight_quant: bool,
bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
# Test for a cutlass kernel with per-token activation quantization
@ -46,10 +47,17 @@ def cutlass_fp8_gemm_helper(m: int,
(m_a_scales, 1), device=device, dtype=torch.float32) / 10)
scale_b = (torch.randn(
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
if bias:
# bias term should be > 1 so that the absolute tolerance can catch it
bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
else:
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
bias_t = 0
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)).to(out_dtype)
baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)) +
bias_t).to(out_dtype)
assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1)
@ -59,6 +67,7 @@ def cutlass_int8_gemm_helper(m: int,
k: int,
per_token_act_quant: bool,
per_out_channel_weight_quant: bool,
bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
# Test for a cutlass kernel with per-token activation quantization
@ -74,11 +83,17 @@ def cutlass_int8_gemm_helper(m: int,
scale_b = (torch.randn(
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b *
b.to(dtype=torch.float32)).to(dtype=out_dtype)
if bias:
# bias term should be > 1 so that the absolute tolerance can catch it
bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
else:
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
bias_t = 0
baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)) +
bias_t).to(dtype=out_dtype)
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
@ -87,11 +102,12 @@ def cutlass_int8_gemm_helper(m: int,
@pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.skipif(capability < 89,
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch: bool):
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch)
per_out_ch: bool, bias: bool):
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
@pytest.mark.parametrize("m", [512, 222, 33, 1])
@ -99,49 +115,72 @@ def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
@pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch: bool):
cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch)
per_out_ch: bool, bias: bool):
cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
out_dtype: Type[torch.dtype]):
cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
out_dtype)
out_dtype: Type[torch.dtype],
bias: bool):
cutlass_int8_gemm_helper(512,
512,
512,
per_act_token,
per_out_ch,
bias,
out_dtype=out_dtype)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.skipif(capability < 89,
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
out_dtype: Type[torch.dtype]):
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
out_dtype)
out_dtype: Type[torch.dtype],
bias: bool):
cutlass_fp8_gemm_helper(512,
512,
512,
per_act_token,
per_out_ch,
bias,
out_dtype=out_dtype)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(capability < 89,
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
device: str):
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
bias: bool, device: str):
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, bias,
torch.bfloat16, device)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
device: str):
cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
torch.bfloat16, device)
bias: bool, device: str):
cutlass_int8_gemm_helper(512,
512,
512,
per_act_token,
per_out_ch,
bias,
out_dtype=torch.bfloat16,
device=device)
# For the following two tests:
@ -151,20 +190,25 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
# kernel must handle any M thrown at it.
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.skipif(capability < 89,
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool):
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
bias: bool):
for nk in range(32, 128, 32):
for m in range(1, 128):
cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch)
cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch, bias)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool):
@pytest.mark.parametrize("bias", [True, False])
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
bias: bool):
for nk in range(32, 128, 32):
for m in range(1, 128):
cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch)
cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
bias)
# Test working with a subset of A and B

View File

@ -220,9 +220,12 @@ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
def cutlass_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype]) -> torch.Tensor:
out_dtype: Type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
@ -230,7 +233,8 @@ def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
n = b.shape[1]
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b)
torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
return out