[Kernel] Add needs_fixed_stride_order tag to most GEMMs (#14306)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith 2025-03-06 17:17:09 -05:00 committed by GitHub
parent 8ca2b21c98
commit 99b0915d3b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,6 +4,7 @@
#include "core/registration.h"
#include <torch/library.h>
#include <torch/version.h>
// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
@ -17,6 +18,15 @@
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
//
// The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
// to override this for many GEMMs with the following tag. Otherwise,
// torch.compile will force all input tensors to be contiguous(), which
// will break many custom ops that require column-major weight matrices.
// TODO: remove this for PyTorch 2.8, when the default is planned to switch
// to match exact eager-mode strides.
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
@ -163,25 +173,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
"Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
"-> Tensor");
"-> Tensor",
{stride_tag});
ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
// Decompression method for AQLM.
ops.def(
"aqlm_dequant(Tensor codes, Tensor codebooks, "
"int[] codebook_partition_sizes) -> Tensor");
"int[] codebook_partition_sizes) -> Tensor",
{stride_tag});
ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
// Quantized GEMM for AWQ.
ops.def(
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
"Tensor _zeros, SymInt split_k_iters) -> Tensor");
"Tensor _zeros, SymInt split_k_iters) -> Tensor",
{stride_tag});
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
// Dequantization for AWQ.
ops.def(
"awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
"Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
"Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor",
{stride_tag});
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
// Note about marlin kernel 'workspace' arguments:
@ -202,7 +216,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
"Tensor");
"Tensor",
{stride_tag});
// conditionally compiled so impl in source file
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
@ -210,7 +225,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
"Tensor b_scales, Tensor workspace, "
"int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
"SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor",
{stride_tag});
// conditionally compiled so impl in source file
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
@ -236,7 +252,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor? channel_scales,"
" Tensor? token_scales,"
" str? schedule"
") -> Tensor");
") -> Tensor",
{stride_tag});
ops.def(
"machete_prepack_B("
" Tensor B,"
@ -255,7 +272,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
"bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
{stride_tag});
// conditionally compiled so impl registration is in source file
// gptq_marlin repack from GPTQ.
@ -291,7 +309,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
"SymInt size_k) -> Tensor");
"SymInt size_k) -> Tensor",
{stride_tag});
// conditionally compiled so impl registration is in source file
// marlin_qqq_gemm for QQQ.
@ -299,14 +318,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
"Tensor! workspace, SymInt size_m, SymInt size_n, "
"SymInt size_k) -> Tensor");
"SymInt size_k) -> Tensor",
{stride_tag});
// conditionally compiled so impl registration is in source file
// CUTLASS nvfp4 block scaled GEMM
ops.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
" Tensor alpha) -> ()",
{stride_tag});
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
@ -314,7 +335,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
" Tensor b_scales, Tensor? bias) -> ()",
{stride_tag});
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
@ -323,7 +345,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
" Tensor? azp, Tensor? bias) -> ()",
{stride_tag});
ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
// Check if cutlass scaled_mm is supported for CUDA devices of the given
@ -351,7 +374,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
" Tensor bt_nzs,"
" Tensor bt_meta, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
" Tensor b_scales, Tensor? bias) -> ()",
{stride_tag});
ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);
// CUTLASS sparse matrix compressor
@ -407,7 +431,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
"-> Tensor");
"-> Tensor",
{stride_tag});
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
// Post processing for GPTQ.