[Kernel] Factor out epilogues from cutlass kernels (#5391)
Co-authored-by: Michael Goin <michael@neuralmagic.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: zifeitong <zifei.tong@parasail.io> Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
This commit is contained in:
parent
0ce7b952f8
commit
85657b5607
@ -179,9 +179,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
|
||||
|
||||
#
|
||||
# The CUTLASS kernels for Hopper require sm90a to be enabled.
|
||||
@ -189,7 +189,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
|
||||
set_source_files_properties(
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
|
||||
PROPERTIES
|
||||
COMPILE_FLAGS
|
||||
"-gencode arch=compute_90a,code=sm_90a")
|
||||
|
@ -76,11 +76,7 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
|
||||
def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||
scale_b: torch.tensor,
|
||||
out_dtype: torch.dtype) -> torch.tensor:
|
||||
return ops.cutlass_scaled_mm_dq(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=out_dtype)
|
||||
return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
|
||||
|
||||
|
||||
# bench
|
||||
|
@ -90,9 +90,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits);
|
||||
|
||||
void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
#endif
|
||||
|
||||
|
@ -29,21 +29,14 @@
|
||||
using namespace cute;
|
||||
|
||||
/*
|
||||
This defines a quantized GEMM operation with dequantized output, similar to
|
||||
torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
|
||||
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
|
||||
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
||||
|
||||
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
|
||||
per-row. B can be quantized per-tensor or per-column.
|
||||
Any combination of per-tensor and per-row or column is supported.
|
||||
A and B must have symmetric quantization (zero point == 0).
|
||||
|
||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||
scales are applied elementwise with numpy-style broadcasting.
|
||||
|
||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
Epilogue functions can be defined to post-process the output before it is
|
||||
written to GPU memory.
|
||||
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
|
||||
as well as a static prepare_args function that constructs an
|
||||
EVTCompute::Arguments struct.
|
||||
*/
|
||||
|
||||
namespace {
|
||||
@ -83,27 +76,25 @@ struct enable_sm89_to_sm90 : Kernel {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Arch, template <typename> typename ArchGuard,
|
||||
typename ElementAB_, typename ElementD_, typename TileShape,
|
||||
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
|
||||
struct cutlass_2x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
/*
|
||||
This epilogue function defines a quantized GEMM operation similar to
|
||||
torch._scaled_mm.
|
||||
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
|
||||
per-row. B can be quantized per-tensor or per-column.
|
||||
Any combination of per-tensor and per-row or column is supported.
|
||||
A and B must have symmetric quantization (zero point == 0).
|
||||
|
||||
using Operator =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::arch::OpMultiplyAdd>::type;
|
||||
|
||||
using OutputTileThreadMap =
|
||||
cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
||||
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
|
||||
>;
|
||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||
scales are applied elementwise with numpy-style broadcasting.
|
||||
|
||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogue {
|
||||
private:
|
||||
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||
|
||||
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||
@ -123,14 +114,56 @@ struct cutlass_2x_gemm {
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute1 =
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using ScaleAArgs = typename ScaleA::Arguments;
|
||||
using ScaleBArgs = typename ScaleB::Arguments;
|
||||
|
||||
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
|
||||
typename EVTCompute0::Arguments evt0_compute_args{b_args};
|
||||
|
||||
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
|
||||
return evt_compute_args;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Arch, template <typename> typename ArchGuard,
|
||||
typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename> typename Epilogue_, typename TileShape,
|
||||
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
|
||||
struct cutlass_2x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using Operator =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::arch::OpMultiplyAdd>::type;
|
||||
|
||||
using OutputTileThreadMap =
|
||||
cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
||||
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
|
||||
>;
|
||||
|
||||
using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
|
||||
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
|
||||
Stride<int64_t, Int<1>, Int<0>>>;
|
||||
|
||||
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>;
|
||||
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
|
||||
|
||||
// clang-format off
|
||||
using RowMajor = typename cutlass::layout::RowMajor;
|
||||
@ -153,11 +186,10 @@ struct cutlass_2x_gemm {
|
||||
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
@ -177,23 +209,14 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
||||
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
|
||||
auto a_scales_ptr = a_scales.data_ptr<float>();
|
||||
auto b_scales_ptr = b_scales.data_ptr<float>();
|
||||
|
||||
using ScaleAArgs = typename Gemm::ScaleA::Arguments;
|
||||
using ScaleBArgs = typename Gemm::ScaleB::Arguments;
|
||||
|
||||
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
|
||||
typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};
|
||||
|
||||
typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
|
||||
evt0_compute_args};
|
||||
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
|
||||
|
||||
using Epilogue = typename Gemm::Epilogue;
|
||||
auto evt_args =
|
||||
Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
|
||||
|
||||
typename Gemm::EVTD::Arguments epilogue_args{
|
||||
evt1_compute_args,
|
||||
evt_args,
|
||||
d_args,
|
||||
};
|
||||
|
||||
@ -229,10 +252,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
||||
} // namespace
|
||||
|
||||
void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
@ -243,23 +266,23 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
|
||||
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
|
||||
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
@ -270,23 +293,23 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
@ -298,32 +321,32 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
assert(out.dtype() == torch::kFloat16);
|
||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
|
||||
cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
}
|
@ -32,21 +32,14 @@
|
||||
using namespace cute;
|
||||
|
||||
/*
|
||||
This defines a quantized GEMM operation with dequantized output, similar to
|
||||
torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm90a (Hopper) or later.
|
||||
|
||||
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
|
||||
per-row. B can be quantized per-tensor or per-column.
|
||||
Any combination of per-tensor and per-row or column is supported.
|
||||
A and B must have symmetric quantization (zero point == 0).
|
||||
|
||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||
scales are applied elementwise with numpy-style broadcasting.
|
||||
|
||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
Epilogue functions can be defined to post-process the output before it is
|
||||
written to GPU memory.
|
||||
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
|
||||
as well as a static prepare_args function that constructs an
|
||||
EVTCompute::Arguments struct.
|
||||
*/
|
||||
|
||||
namespace {
|
||||
@ -71,21 +64,25 @@ struct enable_sm90_or_later : Kernel {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ElementAB_, typename ElementD_, typename TileShape,
|
||||
typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
/*
|
||||
This epilogue function defines a quantized GEMM operation similar to
|
||||
torch.scaled_mm_.
|
||||
|
||||
using EpilogueDescriptor =
|
||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
||||
ElementD, EpilogueSchedule>;
|
||||
A and B may be both either int8 or fp8_e4m3. A can be
|
||||
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
|
||||
Any combination of per-tensor and per-row or column is supported.
|
||||
A and B must have symmetric quantization (zero point == 0).
|
||||
|
||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||
scales are applied elementwise with numpy-style broadcasting.
|
||||
|
||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogue {
|
||||
private:
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
||||
@ -111,19 +108,53 @@ struct cutlass_3x_gemm {
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute1 =
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using ScaleA_Args = typename ScaleA::Arguments;
|
||||
using ScaleB_Args = typename ScaleB::Arguments;
|
||||
|
||||
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
|
||||
return ArgumentType{a_args, {b_args}};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using EpilogueDescriptor =
|
||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
||||
ElementD, EpilogueSchedule>;
|
||||
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
|
||||
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using ElementC = void;
|
||||
using StrideC = StrideD;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
|
||||
EpilogueSchedule, EVTCompute1>::CollectiveOp;
|
||||
EpilogueSchedule, EVTCompute>::CollectiveOp;
|
||||
|
||||
static constexpr size_t CEStorageSize =
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||
@ -148,11 +179,10 @@ struct cutlass_3x_gemm {
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
@ -182,19 +212,13 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, c_ptr, c_stride, c_ptr, c_stride};
|
||||
Gemm::Epilogue::prepare_args(
|
||||
std::forward<EpilogueArgs>(epilogue_params)...),
|
||||
c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape, mainloop_args, epilogue_args};
|
||||
|
||||
using ScaleA_Args = typename Gemm::ScaleA::Arguments;
|
||||
using ScaleB_Args = typename Gemm::ScaleB::Arguments;
|
||||
|
||||
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
|
||||
args.epilogue.thread = {a_args, {b_args}};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
@ -209,7 +233,8 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
template <typename InType, typename OutType, int32_t M>
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue, int32_t M>
|
||||
struct sm90_fp8_config {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
@ -219,12 +244,13 @@ struct sm90_fp8_config {
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
|
||||
EpilogueSchedule>;
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
struct sm90_fp8_config<InType, OutType, 128> {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config<InType, OutType, Epilogue, 128> {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
@ -233,12 +259,13 @@ struct sm90_fp8_config<InType, OutType, 128> {
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
|
||||
EpilogueSchedule>;
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
struct sm90_fp8_config<InType, OutType, 64> {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
@ -247,30 +274,28 @@ struct sm90_fp8_config<InType, OutType, 64> {
|
||||
using ClusterShape = Shape<_1, _8, _1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
|
||||
EpilogueSchedule>;
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_fp8_config<InType, OutType, 0>::Cutlass3xGemm;
|
||||
typename sm90_fp8_config<InType, OutType, Epilogue, 0>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM64 =
|
||||
typename sm90_fp8_config<InType, OutType, 64>::Cutlass3xGemm;
|
||||
typename sm90_fp8_config<InType, OutType, Epilogue, 64>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM128 =
|
||||
typename sm90_fp8_config<InType, OutType, 128>::Cutlass3xGemm;
|
||||
typename sm90_fp8_config<InType, OutType, Epilogue, 128>::Cutlass3xGemm;
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
@ -278,23 +303,23 @@ void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
|
||||
if (mp2 <= 64) {
|
||||
// m in [1, 64]
|
||||
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM64>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM64>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// m in (64, 128]
|
||||
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM128>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM128>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// m in (128, inf)
|
||||
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmDefault>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
@ -308,16 +333,15 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_3x_gemm<int8_t, cutlass::bfloat16_t, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
return cutlass_gemm_caller<cutlass_3x_gemm<
|
||||
int8_t, cutlass::bfloat16_t, ScaledEpilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>>(out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_3x_gemm<int8_t, cutlass::half_t, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>>(
|
||||
return cutlass_gemm_caller<
|
||||
cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
} else {
|
||||
@ -325,13 +349,13 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t>(
|
||||
return cutlass_gemm_sm90_fp8_dispatch<
|
||||
cutlass::float_e4m3_t, cutlass::bfloat16_t, ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t>(
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
@ -3,31 +3,31 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
int32_t major_capability;
|
||||
int32_t minor_capability;
|
||||
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
|
||||
@ -57,19 +57,19 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales);
|
||||
#else
|
||||
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
|
||||
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
|
||||
#endif
|
||||
} else if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);
|
||||
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales);
|
||||
} else if (version_num >= 80) {
|
||||
// Ampere
|
||||
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
|
||||
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
// Turing
|
||||
TORCH_CHECK(version_num >= 75);
|
||||
cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales);
|
||||
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
@ -136,10 +136,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization.
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_dq(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales) -> ()");
|
||||
ops.impl("cutlass_scaled_mm_dq", torch::kCUDA, &cutlass_scaled_mm_dq);
|
||||
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales) -> ()");
|
||||
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
|
||||
#endif
|
||||
|
||||
// Quantized GEMM for GPTQ.
|
||||
|
@ -47,7 +47,7 @@ def cutlass_fp8_gemm_helper(m: int,
|
||||
scale_b = (torch.randn(
|
||||
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
|
||||
|
||||
out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype)
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
|
||||
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b * b.to(dtype=torch.float32)).to(out_dtype)
|
||||
|
||||
@ -74,7 +74,7 @@ def cutlass_int8_gemm_helper(m: int,
|
||||
scale_b = (torch.randn(
|
||||
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
|
||||
|
||||
out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype)
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
|
||||
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b *
|
||||
b.to(dtype=torch.float32)).to(dtype=out_dtype)
|
||||
@ -180,11 +180,11 @@ def test_cutlass_subset():
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
out = ops.cutlass_scaled_mm_dq(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
out = ops.cutlass_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b *
|
||||
b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
|
||||
@ -203,8 +203,8 @@ class CutlassLayer(torch.nn.Module):
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
def forward(self, a):
|
||||
return ops.cutlass_scaled_mm_dq(a, self.b, self.scale_a, self.scale_b,
|
||||
self.out_dtype)
|
||||
return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
|
||||
self.out_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
|
@ -212,9 +212,9 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
|
||||
|
||||
# cutlass
|
||||
def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor,
|
||||
out_dtype: Type[torch.dtype]) -> torch.Tensor:
|
||||
def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: Type[torch.dtype]) -> torch.Tensor:
|
||||
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
|
||||
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
|
||||
|
||||
@ -222,8 +222,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
|
||||
n = b.shape[1]
|
||||
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
||||
|
||||
torch.ops._C.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
|
||||
|
||||
torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b)
|
||||
return out
|
||||
|
||||
|
||||
|
@ -81,5 +81,5 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
|
||||
weight_scale = layer.weight_scale
|
||||
|
||||
x_q, input_scales = custom_ops.scaled_int8_quant(x)
|
||||
return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), input_scales,
|
||||
weight_scale, x.dtype)
|
||||
return custom_ops.cutlass_scaled_mm(x_q, weight.t(), input_scales,
|
||||
weight_scale, x.dtype)
|
||||
|
@ -99,5 +99,5 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
|
||||
# Input quantize
|
||||
x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)
|
||||
|
||||
return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
|
||||
weight_scale, x.dtype)
|
||||
return custom_ops.cutlass_scaled_mm(x_q, weight.t(), act_scale,
|
||||
weight_scale, x.dtype)
|
||||
|
@ -261,7 +261,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
|
||||
|
||||
# Fused GEMM_DQ
|
||||
output = ops.cutlass_scaled_mm_dq(
|
||||
output = ops.cutlass_scaled_mm(
|
||||
qinput,
|
||||
layer.weight,
|
||||
out_dtype=x.dtype,
|
||||
|
Loading…
x
Reference in New Issue
Block a user