[Kernel] Add needs_fixed_stride_order tag to most GEMMs (#14306)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
8ca2b21c98
commit
99b0915d3b
@ -4,6 +4,7 @@
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <torch/library.h>
|
||||
#include <torch/version.h>
|
||||
|
||||
// Note on op signatures:
|
||||
// The X_meta signatures are for the meta functions corresponding to op X.
|
||||
@ -17,6 +18,15 @@
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
//
|
||||
|
||||
// The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
|
||||
// to override this for many GEMMs with the following tag. Otherwise,
|
||||
// torch.compile will force all input tensors to be contiguous(), which
|
||||
// will break many custom ops that require column-major weight matrices.
|
||||
// TODO: remove this for PyTorch 2.8, when the default is planned to switch
|
||||
// to match exact eager-mode strides.
|
||||
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
|
||||
|
||||
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
||||
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
|
||||
@ -163,25 +173,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def(
|
||||
"aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
|
||||
"Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
|
||||
"-> Tensor");
|
||||
"-> Tensor",
|
||||
{stride_tag});
|
||||
ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
|
||||
|
||||
// Decompression method for AQLM.
|
||||
ops.def(
|
||||
"aqlm_dequant(Tensor codes, Tensor codebooks, "
|
||||
"int[] codebook_partition_sizes) -> Tensor");
|
||||
"int[] codebook_partition_sizes) -> Tensor",
|
||||
{stride_tag});
|
||||
ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
|
||||
|
||||
// Quantized GEMM for AWQ.
|
||||
ops.def(
|
||||
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
|
||||
"Tensor _zeros, SymInt split_k_iters) -> Tensor");
|
||||
"Tensor _zeros, SymInt split_k_iters) -> Tensor",
|
||||
{stride_tag});
|
||||
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
|
||||
|
||||
// Dequantization for AWQ.
|
||||
ops.def(
|
||||
"awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
|
||||
"Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
|
||||
"Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor",
|
||||
{stride_tag});
|
||||
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
|
||||
|
||||
// Note about marlin kernel 'workspace' arguments:
|
||||
@ -202,7 +216,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def(
|
||||
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
||||
"Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
|
||||
"Tensor");
|
||||
"Tensor",
|
||||
{stride_tag});
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
|
||||
@ -210,7 +225,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
|
||||
"Tensor b_scales, Tensor workspace, "
|
||||
"int b_q_type, "
|
||||
"SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
|
||||
"SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor",
|
||||
{stride_tag});
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
|
||||
@ -236,7 +252,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor? channel_scales,"
|
||||
" Tensor? token_scales,"
|
||||
" str? schedule"
|
||||
") -> Tensor");
|
||||
") -> Tensor",
|
||||
{stride_tag});
|
||||
ops.def(
|
||||
"machete_prepack_B("
|
||||
" Tensor B,"
|
||||
@ -255,7 +272,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
|
||||
"int b_q_type, "
|
||||
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
|
||||
"bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
||||
"bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
|
||||
{stride_tag});
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// gptq_marlin repack from GPTQ.
|
||||
@ -291,7 +309,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def(
|
||||
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
||||
"Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
|
||||
"SymInt size_k) -> Tensor");
|
||||
"SymInt size_k) -> Tensor",
|
||||
{stride_tag});
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// marlin_qqq_gemm for QQQ.
|
||||
@ -299,14 +318,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
|
||||
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
|
||||
"Tensor! workspace, SymInt size_m, SymInt size_n, "
|
||||
"SymInt size_k) -> Tensor");
|
||||
"SymInt size_k) -> Tensor",
|
||||
{stride_tag});
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// CUTLASS nvfp4 block scaled GEMM
|
||||
ops.def(
|
||||
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
" Tensor block_scale_a, Tensor block_scale_b,"
|
||||
" Tensor alpha) -> ()");
|
||||
" Tensor alpha) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
@ -314,7 +335,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def(
|
||||
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor? bias) -> ()");
|
||||
" Tensor b_scales, Tensor? bias) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
|
||||
@ -323,7 +345,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor azp_adj,"
|
||||
" Tensor? azp, Tensor? bias) -> ()");
|
||||
" Tensor? azp, Tensor? bias) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
|
||||
|
||||
// Check if cutlass scaled_mm is supported for CUDA devices of the given
|
||||
@ -351,7 +374,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
|
||||
" Tensor bt_nzs,"
|
||||
" Tensor bt_meta, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor? bias) -> ()");
|
||||
" Tensor b_scales, Tensor? bias) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);
|
||||
|
||||
// CUTLASS sparse matrix compressor
|
||||
@ -407,7 +431,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def(
|
||||
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
|
||||
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
|
||||
"-> Tensor");
|
||||
"-> Tensor",
|
||||
{stride_tag});
|
||||
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
|
||||
|
||||
// Post processing for GPTQ.
|
||||
|
Loading…
x
Reference in New Issue
Block a user