[Kernel] Add needs_fixed_stride_order tag to most GEMMs (#14306)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Tyler Michael Smith 2025-03-06 17:17:09 -05:00 committed by GitHub
parent 8ca2b21c98
commit 99b0915d3b

@@ -4,6 +4,7 @@
 #include "core/registration.h"
 #include <torch/library.h>
+#include <torch/version.h>
 
 // Note on op signatures:
 // The X_meta signatures are for the meta functions corresponding to op X.
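
The new <torch/version.h> include brings in PyTorch's compile-time version macros (TORCH_VERSION_MAJOR, TORCH_VERSION_MINOR, TORCH_VERSION_PATCH). A minimal sketch of how such a macro could eventually gate the tag; this guard is an illustration of the idea (with an illustrative name, kStrideTags), not part of this commit:

#include <torch/library.h>
#include <torch/version.h>
#include <vector>

// Sketch only: pick the tag set once, based on the PyTorch version the
// extension is compiled against.
#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 8
// Pre-2.8: the default ("requires_contiguous") would let torch.compile make
// inputs contiguous, so pin strides explicitly.
static const std::vector<at::Tag> kStrideTags = {
    at::Tag::needs_fixed_stride_order};
#else
// 2.8+ is planned to match exact eager-mode strides by default.
static const std::vector<at::Tag> kStrideTags = {};
#endif

Since ops.def takes the tags as a std::vector<at::Tag>, passing kStrideTags at each call site would then work unchanged on either version.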
@@ -17,6 +18,15 @@
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
+  //
+  // The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
+  // to override this for many GEMMs with the following tag. Otherwise,
+  // torch.compile will force all input tensors to be contiguous(), which
+  // will break many custom ops that require column-major weight matrices.
+  // TODO: remove this for PyTorch 2.8, when the default is planned to switch
+  // to match exact eager-mode strides.
+  at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
+
   ops.def("weak_ref_tensor(Tensor input) -> Tensor");
   ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
@@ -163,25 +173,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
       "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
-      "-> Tensor");
+      "-> Tensor",
+      {stride_tag});
   ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
 
   // Decompression method for AQLM.
   ops.def(
       "aqlm_dequant(Tensor codes, Tensor codebooks, "
-      "int[] codebook_partition_sizes) -> Tensor");
+      "int[] codebook_partition_sizes) -> Tensor",
+      {stride_tag});
   ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
 
   // Quantized GEMM for AWQ.
   ops.def(
       "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters) -> Tensor",
+      {stride_tag});
   ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
 
   // Dequantization for AWQ.
   ops.def(
       "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor",
+      {stride_tag});
   ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
 
   // Note about marlin kernel 'workspace' arguments:
@@ -202,7 +216,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
-      "Tensor");
+      "Tensor",
+      {stride_tag});
   // conditionally compiled so impl in source file
 
   // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
@@ -210,7 +225,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
       "Tensor b_scales, Tensor workspace, "
       "int b_q_type, "
-      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
+      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor",
+      {stride_tag});
   // conditionally compiled so impl in source file
 
   // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
@@ -236,7 +252,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor? channel_scales,"
       " Tensor? token_scales,"
       " str? schedule"
-      ") -> Tensor");
+      ") -> Tensor",
+      {stride_tag});
   ops.def(
       "machete_prepack_B("
       " Tensor B,"
@@ -255,7 +272,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
       "int b_q_type, "
       "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
-      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
+      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
+      {stride_tag});
   // conditionally compiled so impl registration is in source file
 
   // gptq_marlin repack from GPTQ.
@@ -291,7 +309,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
-      "SymInt size_k) -> Tensor");
+      "SymInt size_k) -> Tensor",
+      {stride_tag});
   // conditionally compiled so impl registration is in source file
 
   // marlin_qqq_gemm for QQQ.
@@ -299,14 +318,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
       "Tensor s_tok, Tensor s_ch, Tensor s_group, "
       "Tensor! workspace, SymInt size_m, SymInt size_n, "
-      "SymInt size_k) -> Tensor");
+      "SymInt size_k) -> Tensor",
+      {stride_tag});
   // conditionally compiled so impl registration is in source file
 
   // CUTLASS nvfp4 block scaled GEMM
   ops.def(
       "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
       " Tensor block_scale_a, Tensor block_scale_b,"
-      " Tensor alpha) -> ()");
+      " Tensor alpha) -> ()",
+      {stride_tag});
   ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
 
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
@@ -314,7 +335,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "cutlass_scaled_mm(Tensor! out, Tensor a,"
       " Tensor b, Tensor a_scales,"
-      " Tensor b_scales, Tensor? bias) -> ()");
+      " Tensor b_scales, Tensor? bias) -> ()",
+      {stride_tag});
   ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
 
   // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
@@ -323,7 +345,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
       " Tensor b, Tensor a_scales,"
       " Tensor b_scales, Tensor azp_adj,"
-      " Tensor? azp, Tensor? bias) -> ()");
+      " Tensor? azp, Tensor? bias) -> ()",
+      {stride_tag});
   ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
 
   // Check if cutlass scaled_mm is supported for CUDA devices of the given
@@ -351,7 +374,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
       " Tensor bt_nzs,"
       " Tensor bt_meta, Tensor a_scales,"
-      " Tensor b_scales, Tensor? bias) -> ()");
+      " Tensor b_scales, Tensor? bias) -> ()",
+      {stride_tag});
   ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);
 
   // CUTLASS sparse matrix compressor
@@ -407,7 +431,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
       "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
-      "-> Tensor");
+      "-> Tensor",
+      {stride_tag});
   ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
 
   // Post processing for GPTQ.