[Kernel] Tuned int8 kernels for Ada Lovelace (#6848)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
parent
61a97c32f6
commit
af647fb8b3
@ -4,7 +4,8 @@
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
#include "scaled_mm_c2x_sm80_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
|
||||
@ -98,25 +99,17 @@ template <template <typename, typename> typename Epilogue,
|
||||
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
if (a.dtype() == torch::kInt8) {
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return vllm::cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
int8_t, cutlass::bfloat16_t, Epilogue,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(
|
||||
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
assert(out.dtype() == torch::kFloat16);
|
||||
return vllm::cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
int8_t, cutlass::half_t, Epilogue, TileShape,
|
||||
WarpShape, InstructionShape, 5>>(
|
||||
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else {
|
||||
@ -124,12 +117,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return vllm::cutlass_gemm_sm89_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
return vllm::cutlass_gemm_sm89_fp8_dispatch<
|
||||
cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return vllm::cutlass_gemm_sm89_dispatch<cutlass::float_e4m3_t,
|
||||
return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
|
@ -4,7 +4,7 @@
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM89 based on the Gemm
|
||||
* This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
|
||||
* shape.
|
||||
*/
|
||||
|
||||
@ -12,7 +12,7 @@ namespace vllm {
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm89_fallback_gemm {
|
||||
struct sm89_fp8_fallback_gemm {
|
||||
// Shared Memory required by this Gemm - 61440 bytes
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
@ -25,7 +25,7 @@ struct sm89_fallback_gemm {
|
||||
FP8MathOperator>;
|
||||
};
|
||||
|
||||
struct sm89_config_default {
|
||||
struct sm89_fp8_config_default {
|
||||
// M in (256, inf)
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
@ -40,7 +40,8 @@ struct sm89_config_default {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
@ -74,7 +75,7 @@ struct sm89_config_default {
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_config_M256 {
|
||||
struct sm89_fp8_config_M256 {
|
||||
// M in (128, 256]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
@ -89,7 +90,8 @@ struct sm89_config_M256 {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
@ -114,7 +116,7 @@ struct sm89_config_M256 {
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_config_M128 {
|
||||
struct sm89_fp8_config_M128 {
|
||||
// M in (64, 128]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
@ -129,7 +131,8 @@ struct sm89_config_M128 {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
@ -163,7 +166,7 @@ struct sm89_config_M128 {
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_config_M64 {
|
||||
struct sm89_fp8_config_M64 {
|
||||
// M in (32, 64]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
@ -176,7 +179,8 @@ struct sm89_config_M64 {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
@ -215,7 +219,7 @@ struct sm89_config_M64 {
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_config_M32 {
|
||||
struct sm89_fp8_config_M32 {
|
||||
// M in (16, 32]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
|
||||
@ -229,7 +233,8 @@ struct sm89_config_M32 {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
@ -265,7 +270,7 @@ struct sm89_config_M32 {
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_config_M16 {
|
||||
struct sm89_fp8_config_M16 {
|
||||
// M in [1, 16]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
@ -281,7 +286,8 @@ struct sm89_config_M16 {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
@ -320,7 +326,7 @@ struct sm89_config_M16 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm89_dispatch(torch::Tensor& out,
|
||||
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
@ -334,27 +340,27 @@ inline void cutlass_gemm_sm89_dispatch(torch::Tensor& out,
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
return sm89_config_M16::dispatch<InType, OutType, Epilogue>(
|
||||
return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 32) {
|
||||
// M in (16, 32]
|
||||
return sm89_config_M32::dispatch<InType, OutType, Epilogue>(
|
||||
return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// M in (32, 64]
|
||||
return sm89_config_M64::dispatch<InType, OutType, Epilogue>(
|
||||
return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// M in (64, 128]
|
||||
return sm89_config_M128::dispatch<InType, OutType, Epilogue>(
|
||||
return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 256) {
|
||||
// M in (128, 256]
|
||||
return sm89_config_M256::dispatch<InType, OutType, Epilogue>(
|
||||
return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// M in (256, inf)
|
||||
return sm89_config_default::dispatch<InType, OutType, Epilogue>(
|
||||
return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
@ -0,0 +1,353 @@
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM89 (int8) based on the
|
||||
* Gemm shape.
|
||||
*/
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm89_int8_fallback_gemm {
|
||||
// Shared mem requirement : 61440
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
static int32_t const MainLoopStages = 5;
|
||||
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
|
||||
};
|
||||
|
||||
struct sm89_int8_config_default {
|
||||
// M in (256, inf)
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 4096) {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M256 {
|
||||
// M in (128, 256]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 4096) {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M128 {
|
||||
// M in (64, 128]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M64 {
|
||||
// M in (32, 64]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M32 {
|
||||
// M in (16, 32]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 4>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M16 {
|
||||
// M in [1, 16]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<16, 64, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<16, 128, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 4>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
return sm89_int8_config_M16::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 32) {
|
||||
// M in (16, 32]
|
||||
return sm89_int8_config_M32::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// M in (32, 64]
|
||||
return sm89_int8_config_M64::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// M in (64, 128]
|
||||
return sm89_int8_config_M128::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 256) {
|
||||
// M in (128, 256]
|
||||
return sm89_int8_config_M256::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// M in (256, inf)
|
||||
return sm89_int8_config_default::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
@ -119,8 +119,8 @@ def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
|
||||
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [512, 222, 33, 1])
|
||||
@pytest.mark.parametrize("n", [2048, 256, 1024])
|
||||
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 33, 1])
|
||||
@pytest.mark.parametrize("n", [2048, 8192, 16384, 256, 1024])
|
||||
@pytest.mark.parametrize("k", [128, 496, 1024])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
|
Loading…
x
Reference in New Issue
Block a user