[Kernel] Tuned int8 kernels for Ada Lovelace (#6848)

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Author: Varun Sundar Rabindranath
Date: 2024-07-29 22:24:58 -04:00 (committed by GitHub)
parent 61a97c32f6
commit af647fb8b3
4 changed files with 395 additions and 43 deletions

Changed file 1 of 4: the CUTLASS 2.x scaled_mm entry point (cutlass_scaled_mm_sm89_epilogue), which now routes int8 and FP8 problems to separate shape-tuned dispatch headers.

@@ -4,7 +4,8 @@
 #include "scaled_mm_c2x.cuh"
 #include "scaled_mm_c2x_sm80_dispatch.cuh"
-#include "scaled_mm_c2x_sm89_dispatch.cuh"
+#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
+#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
 
 /*
    This file defines quantized GEMM operations using the CUTLASS 2.x API, for
@@ -98,25 +99,17 @@ template <template <typename, typename> typename Epilogue,
 void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                      torch::Tensor const& b,
                                      EpilogueArgs&&... epilogue_args) {
-  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
-  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
-  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
-
   if (a.dtype() == torch::kInt8) {
     TORCH_CHECK(b.dtype() == torch::kInt8);
 
     if (out.dtype() == torch::kBFloat16) {
-      return vllm::cutlass_gemm_caller<
-          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
-                                int8_t, cutlass::bfloat16_t, Epilogue,
-                                TileShape, WarpShape, InstructionShape, 5>>(
+      return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
+                                                   Epilogue>(
           out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       assert(out.dtype() == torch::kFloat16);
-      return vllm::cutlass_gemm_caller<
-          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
-                                int8_t, cutlass::half_t, Epilogue, TileShape,
-                                WarpShape, InstructionShape, 5>>(
+      return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
+                                                   Epilogue>(
           out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     }
   } else {
@@ -124,13 +117,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
     TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
     if (out.dtype() == torch::kBFloat16) {
-      return vllm::cutlass_gemm_sm89_dispatch<cutlass::float_e4m3_t,
-                                              cutlass::bfloat16_t, Epilogue>(
+      return vllm::cutlass_gemm_sm89_fp8_dispatch<
+          cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
           out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
-      return vllm::cutlass_gemm_sm89_dispatch<cutlass::float_e4m3_t,
-                                              cutlass::half_t, Epilogue>(
+      return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
+                                                  cutlass::half_t, Epilogue>(
           out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     }
   }
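
Before this change, every int8 GEMM on SM89 used a single fixed configuration (128x128x64 threadblock tile, 64x64x64 warp tile, 5 mainloop stages) regardless of problem shape; the new cutlass_gemm_sm89_int8_dispatch instead picks a configuration per M bucket, with smaller tiles for small M. A rough, self-contained illustration of the shared-memory footprint involved, counting only the staged A/B operand tiles, which reproduces the 61440-byte figure quoted in the new int8 header (this snippet is illustrative and not part of the commit):

#include <cstdint>
#include <cstdio>

// Per-stage shared memory for the A (MxK) and B (KxN) operand tiles of an
// int8 CUTLASS 2.x multistage mainloop, times the number of pipeline stages.
// Illustrative arithmetic only; epilogue and other shared-memory uses are ignored.
constexpr uint32_t operand_smem_bytes(uint32_t tile_m, uint32_t tile_n,
                                      uint32_t tile_k, uint32_t stages) {
  return (tile_m * tile_k + tile_k * tile_n) * sizeof(int8_t) * stages;
}

int main() {
  // Old one-size-fits-all int8 config removed above: 128x128x64 tile, 5 stages.
  std::printf("128x128x64, 5 stages: %u bytes\n",
              operand_smem_bytes(128, 128, 64, 5));  // 81920
  // Small-tile fallback config added in scaled_mm_c2x_sm89_int8_dispatch.cuh:
  // 32x64x128 tile, 5 stages.
  std::printf("32x64x128,  5 stages: %u bytes\n",
              operand_smem_bytes(32, 64, 128, 5));   // 61440
  return 0;
}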

Changed file 2 of 4: the SM89 FP8 dispatch header, renamed from scaled_mm_c2x_sm89_dispatch.cuh to scaled_mm_c2x_sm89_fp8_dispatch.cuh, with its config structs and dispatch function renamed to include fp8.

@@ -4,7 +4,7 @@
 #include "cutlass/float8.h"
 
 /**
- * This file defines Gemm kernel configurations for SM89 based on the Gemm
+ * This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
  * shape.
  */
@@ -12,7 +12,7 @@ namespace vllm {
 
 template <typename InType, typename OutType,
           template <typename, typename> typename Epilogue>
-struct sm89_fallback_gemm {
+struct sm89_fp8_fallback_gemm {
   // Shared Memory required by this Gemm - 61440 bytes
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
@@ -25,7 +25,7 @@ struct sm89_fallback_gemm {
                       FP8MathOperator>;
 };
 
-struct sm89_config_default {
+struct sm89_fp8_config_default {
   // M in (256, inf)
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -40,7 +40,8 @@ struct sm89_config_default {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
 
     using FallbackGemm =
-        typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+        typename sm89_fp8_fallback_gemm<InType, OutType,
+                                        Epilogue>::Cutlass2xGemm;
 
     uint32_t const n = out.size(1);
     uint32_t const np2 = next_pow_2(n);
@@ -74,7 +75,7 @@ struct sm89_config_default {
   }
 };
 
-struct sm89_config_M256 {
+struct sm89_fp8_config_M256 {
   // M in (128, 256]
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -89,7 +90,8 @@ struct sm89_config_M256 {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
 
     using FallbackGemm =
-        typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+        typename sm89_fp8_fallback_gemm<InType, OutType,
+                                        Epilogue>::Cutlass2xGemm;
 
     uint32_t const n = out.size(1);
     uint32_t const np2 = next_pow_2(n);
@@ -114,7 +116,7 @@ struct sm89_config_M256 {
   }
 };
 
-struct sm89_config_M128 {
+struct sm89_fp8_config_M128 {
   // M in (64, 128]
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -129,7 +131,8 @@ struct sm89_config_M128 {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
 
     using FallbackGemm =
-        typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+        typename sm89_fp8_fallback_gemm<InType, OutType,
+                                        Epilogue>::Cutlass2xGemm;
 
     uint32_t const n = out.size(1);
     uint32_t const np2 = next_pow_2(n);
@@ -163,7 +166,7 @@ struct sm89_config_M128 {
   }
 };
 
-struct sm89_config_M64 {
+struct sm89_fp8_config_M64 {
   // M in (32, 64]
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -176,7 +179,8 @@ struct sm89_config_M64 {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
 
     using FallbackGemm =
-        typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+        typename sm89_fp8_fallback_gemm<InType, OutType,
+                                        Epilogue>::Cutlass2xGemm;
 
     uint32_t const n = out.size(1);
     uint32_t const np2 = next_pow_2(n);
@@ -215,7 +219,7 @@ struct sm89_config_M64 {
   }
 };
 
-struct sm89_config_M32 {
+struct sm89_fp8_config_M32 {
   // M in (16, 32]
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
   using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
@@ -229,7 +233,8 @@ struct sm89_config_M32 {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
 
     using FallbackGemm =
-        typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+        typename sm89_fp8_fallback_gemm<InType, OutType,
+                                        Epilogue>::Cutlass2xGemm;
 
     uint32_t const n = out.size(1);
     uint32_t const np2 = next_pow_2(n);
@@ -265,7 +270,7 @@ struct sm89_config_M32 {
   }
 };
 
-struct sm89_config_M16 {
+struct sm89_fp8_config_M16 {
   // M in [1, 16]
   using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -281,7 +286,8 @@ struct sm89_config_M16 {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
 
     using FallbackGemm =
-        typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+        typename sm89_fp8_fallback_gemm<InType, OutType,
+                                        Epilogue>::Cutlass2xGemm;
 
     uint32_t const n = out.size(1);
     uint32_t const np2 = next_pow_2(n);
@@ -320,10 +326,10 @@ struct sm89_config_M16 {
 template <typename InType, typename OutType,
           template <typename, typename> typename Epilogue,
           typename... EpilogueArgs>
-inline void cutlass_gemm_sm89_dispatch(torch::Tensor& out,
-                                       torch::Tensor const& a,
-                                       torch::Tensor const& b,
-                                       EpilogueArgs&&... args) {
+inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
+                                           torch::Tensor const& a,
+                                           torch::Tensor const& b,
+                                           EpilogueArgs&&... args) {
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
@@ -334,27 +340,27 @@ inline void cutlass_gemm_sm89_dispatch(torch::Tensor& out,
   if (mp2 <= 16) {
     // M in [1, 16]
-    return sm89_config_M16::dispatch<InType, OutType, Epilogue>(
+    return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
         out, a, b, std::forward<EpilogueArgs>(args)...);
   } else if (mp2 <= 32) {
     // M in (16, 32]
-    return sm89_config_M32::dispatch<InType, OutType, Epilogue>(
+    return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
         out, a, b, std::forward<EpilogueArgs>(args)...);
   } else if (mp2 <= 64) {
     // M in (32, 64]
-    return sm89_config_M64::dispatch<InType, OutType, Epilogue>(
+    return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
   } else if (mp2 <= 128) {
     // M in (64, 128]
-    return sm89_config_M128::dispatch<InType, OutType, Epilogue>(
+    return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
   } else if (mp2 <= 256) {
     // M in (128, 256]
-    return sm89_config_M256::dispatch<InType, OutType, Epilogue>(
+    return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
   } else {
     // M in (256, inf)
-    return sm89_config_default::dispatch<InType, OutType, Epilogue>(
+    return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
   }
 }
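
Both the FP8 dispatcher above and the new int8 dispatcher below follow the same pattern: round M up to the next power of two, pick a per-bucket tile/warp/stage configuration, and give every bucket a conservative fallback kernel. A minimal, self-contained sketch of that bucketing heuristic (illustrative only: next_pow_2 here is a stand-in for the helper the vLLM headers use, and the minimum clamp the real int8 dispatch applies to mp2 is omitted):

#include <bit>      // std::bit_ceil (C++20)
#include <cstdint>
#include <cstdio>

// Stand-in for the next_pow_2 helper used by the dispatchers: the smallest
// power of two that is >= v.
static inline uint32_t next_pow_2(uint32_t v) { return std::bit_ceil(v); }

// Sketch of the M-bucket selection performed by cutlass_gemm_sm89_*_dispatch;
// each bucket maps to a differently tuned CUTLASS configuration struct.
static const char* sm89_m_bucket(uint32_t m) {
  uint32_t const mp2 = next_pow_2(m);
  if (mp2 <= 16) return "config_M16";     // M in [1, 16]
  if (mp2 <= 32) return "config_M32";     // M in (16, 32]
  if (mp2 <= 64) return "config_M64";     // M in (32, 64]
  if (mp2 <= 128) return "config_M128";   // M in (64, 128]
  if (mp2 <= 256) return "config_M256";   // M in (128, 256]
  return "config_default";                // M in (256, inf)
}

int main() {
  for (uint32_t m : {1u, 16u, 33u, 128u, 300u}) {
    std::printf("M = %3u -> %s\n", m, sm89_m_bucket(m));
  }
  return 0;
}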

Changed file 3 of 4: scaled_mm_c2x_sm89_int8_dispatch.cuh, a new header added in full, defining shape-tuned int8 GEMM configurations for SM89.

@@ -0,0 +1,353 @@
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM89 (int8) based on the
* Gemm shape.
*/
namespace vllm {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm89_int8_fallback_gemm {
// Shared mem requirement : 61440
static_assert(std::is_same<InType, int8_t>());
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
static int32_t const MainLoopStages = 5;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
struct sm89_int8_config_default {
// M in (256, inf)
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 4096) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M256 {
// M in (128, 256]
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 4096) {
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M128 {
// M in (64, 128]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M64 {
// M in (32, 64]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M32 {
// M in (16, 32]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M16 {
// M in [1, 16]
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<16, 64, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<16, 128, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]
return sm89_int8_config_M16::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 32) {
// M in (16, 32]
return sm89_int8_config_M32::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return sm89_int8_config_M64::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
return sm89_int8_config_M128::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// M in (128, 256]
return sm89_int8_config_M256::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// M in (256, inf)
return sm89_int8_config_default::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace vllm
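
Every shape-specific config above routes through fallback_cutlass_gemm_caller with sm89_int8_fallback_gemm as the safety net, so a problem that the tuned configuration cannot run (for example, one the instantiated CUTLASS kernel rejects) still executes on the conservative 32x64x128 kernel. A self-contained sketch of that primary/fallback idea; the TunedGemm/FallbackGemm names and the simple can_implement predicate are purely illustrative, not the CUTLASS or vLLM API:

#include <cstdio>

// Stand-ins for two kernel configurations. The real caller consults the
// instantiated CUTLASS kernel's status; the predicate here is hypothetical.
struct TunedGemm {
  static bool can_implement(int /*m*/, int n, int /*k*/) {
    // Hypothetical constraint: pretend the tuned tile needs N divisible by 128.
    return n % 128 == 0;
  }
  static void run(int m, int n, int k) {
    std::printf("tuned kernel    : %d x %d x %d\n", m, n, k);
  }
};

struct FallbackGemm {
  static bool can_implement(int, int, int) { return true; }
  static void run(int m, int n, int k) {
    std::printf("fallback kernel : %d x %d x %d\n", m, n, k);
  }
};

// Shape of the fallback_cutlass_gemm_caller<Primary, Fallback> idea: try the
// shape-tuned configuration first, otherwise run the conservative one.
template <typename Primary, typename Fallback>
void fallback_gemm_caller(int m, int n, int k) {
  if (Primary::can_implement(m, n, k)) {
    Primary::run(m, n, k);
  } else {
    Fallback::run(m, n, k);
  }
}

int main() {
  fallback_gemm_caller<TunedGemm, FallbackGemm>(64, 4096, 4096);  // tuned path
  fallback_gemm_caller<TunedGemm, FallbackGemm>(64, 100, 4096);   // fallback path
  return 0;
}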

Changed file 4 of 4: the CUTLASS GEMM kernel tests, with the m and n parametrizations expanded to exercise the new shape-dispatch buckets.

@@ -119,8 +119,8 @@ def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
     cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
 
 
-@pytest.mark.parametrize("m", [512, 222, 33, 1])
-@pytest.mark.parametrize("n", [2048, 256, 1024])
+@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 33, 1])
+@pytest.mark.parametrize("n", [2048, 8192, 16384, 256, 1024])
 @pytest.mark.parametrize("k", [128, 496, 1024])
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])