From 99b0915d3bcce5beeb812407b354179ee3092b4d Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Thu, 6 Mar 2025 17:17:09 -0500
Subject: [PATCH] [Kernel] Add needs_fixed_stride_order tag to most GEMMs (#14306)

Signed-off-by: Tyler Michael Smith
---
 csrc/torch_bindings.cpp | 55 ++++++++++++++++++++++++++++++-----------
 1 file changed, 40 insertions(+), 15 deletions(-)

diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 0b0334f8..fe7a674b 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -4,6 +4,7 @@
 #include "core/registration.h"
 
 #include <torch/library.h>
+#include
 
 // Note on op signatures:
 // The X_meta signatures are for the meta functions corresponding to op X.
@@ -17,6 +18,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
+  //
+
+  // The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
+  // to override this for many GEMMs with the following tag. Otherwise,
+  // torch.compile will force all input tensors to be contiguous(), which
+  // will break many custom ops that require column-major weight matrices.
+  // TODO: remove this for PyTorch 2.8, when the default is planned to switch
+  // to match exact eager-mode strides.
+  at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
 
   ops.def("weak_ref_tensor(Tensor input) -> Tensor");
   ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
@@ -163,25 +173,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
       "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
-      "-> Tensor");
+      "-> Tensor",
+      {stride_tag});
   ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
 
   // Decompression method for AQLM.
   ops.def(
       "aqlm_dequant(Tensor codes, Tensor codebooks, "
-      "int[] codebook_partition_sizes) -> Tensor");
+      "int[] codebook_partition_sizes) -> Tensor",
+      {stride_tag});
   ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
 
   // Quantized GEMM for AWQ.
   ops.def(
       "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters) -> Tensor",
+      {stride_tag});
   ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
 
   // Dequantization for AWQ.
   ops.def(
       "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor",
+      {stride_tag});
   ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
 
   // Note about marlin kernel 'workspace' arguments:
@@ -202,7 +216,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
-      "Tensor");
+      "Tensor",
+      {stride_tag});
   // conditionally compiled so impl in source file
 
   // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
@@ -210,7 +225,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
       "Tensor b_scales, Tensor workspace, "
       "int b_q_type, "
-      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
+      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor",
+      {stride_tag});
   // conditionally compiled so impl in source file
 
   // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
@@ -236,7 +252,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "  Tensor? channel_scales,"
       "  Tensor? token_scales,"
      "  str? schedule"
schedule" - ") -> Tensor"); + ") -> Tensor", + {stride_tag}); ops.def( "machete_prepack_B(" " Tensor B," @@ -255,7 +272,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, " "int b_q_type, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " - "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor"); + "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor", + {stride_tag}); // conditionally compiled so impl registration is in source file // gptq_marlin repack from GPTQ. @@ -291,7 +309,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, " - "SymInt size_k) -> Tensor"); + "SymInt size_k) -> Tensor", + {stride_tag}); // conditionally compiled so impl registration is in source file // marlin_qqq_gemm for QQQ. @@ -299,14 +318,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, " "Tensor s_tok, Tensor s_ch, Tensor s_group, " "Tensor! workspace, SymInt size_m, SymInt size_n, " - "SymInt size_k) -> Tensor"); + "SymInt size_k) -> Tensor", + {stride_tag}); // conditionally compiled so impl registration is in source file // CUTLASS nvfp4 block scaled GEMM ops.def( "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b," " Tensor block_scale_a, Tensor block_scale_b," - " Tensor alpha) -> ()"); + " Tensor alpha) -> ()", + {stride_tag}); ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm); // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column @@ -314,7 +335,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "cutlass_scaled_mm(Tensor! out, Tensor a," " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor? bias) -> ()"); + " Tensor b_scales, Tensor? bias) -> ()", + {stride_tag}); ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm); // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column @@ -323,7 +345,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "cutlass_scaled_mm_azp(Tensor! out, Tensor a," " Tensor b, Tensor a_scales," " Tensor b_scales, Tensor azp_adj," - " Tensor? azp, Tensor? bias) -> ()"); + " Tensor? azp, Tensor? bias) -> ()", + {stride_tag}); ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp); // Check if cutlass scaled_mm is supported for CUDA devices of the given @@ -351,7 +374,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "cutlass_scaled_sparse_mm(Tensor! out, Tensor a," " Tensor bt_nzs," " Tensor bt_meta, Tensor a_scales," - " Tensor b_scales, Tensor? bias) -> ()"); + " Tensor b_scales, Tensor? bias) -> ()", + {stride_tag}); ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm); // CUTLASS sparse matrix compressor @@ -407,7 +431,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, " "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) " - "-> Tensor"); + "-> Tensor", + {stride_tag}); ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); // Post processing for GPTQ.