[Kernel] Add needs_fixed_stride_order tag to most GEMMs (#14306)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

parent 8ca2b21c98
commit 99b0915d3b
@@ -4,6 +4,7 @@
 #include "core/registration.h"
 
 #include <torch/library.h>
+#include <torch/version.h>
 
 // Note on op signatures:
 // The X_meta signatures are for the meta functions corresponding to op X.
@@ -17,6 +18,15 @@
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
+  //
+
+  // The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
+  // to override this for many GEMMs with the following tag. Otherwise,
+  // torch.compile will force all input tensors to be contiguous(), which
+  // will break many custom ops that require column-major weight matrices.
+  // TODO: remove this for PyTorch 2.8, when the default is planned to switch
+  // to match exact eager-mode strides.
+  at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
 
   ops.def("weak_ref_tensor(Tensor input) -> Tensor");
   ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
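For context, a minimal sketch of what the comment above describes: a custom op that relies on a non-contiguous (column-major) input can opt out of the PyTorch 2.6 "requires_contiguous" default by registering its schema with the needs_fixed_stride_order tag, the same pattern the hunks below apply via stride_tag. The my_ops namespace and column_major_gemm kernel here are hypothetical placeholders, not part of this commit:

    #include <ATen/ATen.h>
    #include <torch/library.h>

    // Hypothetical kernel that expects b to keep a column-major (transposed) layout.
    at::Tensor column_major_gemm(const at::Tensor& a, const at::Tensor& b) {
      // Placeholder body; a real kernel would read b with column-major strides.
      return at::matmul(a, b);
    }

    TORCH_LIBRARY(my_ops, m) {
      // Without the tag, torch.compile under PyTorch 2.6 may call .contiguous()
      // on the inputs, silently converting b to row-major and breaking the
      // kernel's layout assumption. The tag asks the compiler to pass tensors
      // through with their eager-mode strides instead.
      m.def("column_major_gemm(Tensor a, Tensor b) -> Tensor",
            {at::Tag::needs_fixed_stride_order});
      m.impl("column_major_gemm", torch::kCUDA, &column_major_gemm);
    }

A compiled caller of torch.ops.my_ops.column_major_gemm(a, w.t()) would then receive w.t() with its transposed strides intact rather than a contiguous copy.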
@@ -163,25 +173,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
       "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
-      "-> Tensor");
+      "-> Tensor",
+      {stride_tag});
   ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
 
   // Decompression method for AQLM.
   ops.def(
       "aqlm_dequant(Tensor codes, Tensor codebooks, "
-      "int[] codebook_partition_sizes) -> Tensor");
+      "int[] codebook_partition_sizes) -> Tensor",
+      {stride_tag});
   ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
 
   // Quantized GEMM for AWQ.
   ops.def(
       "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters) -> Tensor",
+      {stride_tag});
   ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
 
   // Dequantization for AWQ.
   ops.def(
       "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor",
+      {stride_tag});
   ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
 
   // Note about marlin kernel 'workspace' arguments:
@@ -202,7 +216,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
-      "Tensor");
+      "Tensor",
+      {stride_tag});
   // conditionally compiled so impl in source file
 
   // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
@@ -210,7 +225,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
       "Tensor b_scales, Tensor workspace, "
       "int b_q_type, "
-      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
+      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor",
+      {stride_tag});
   // conditionally compiled so impl in source file
 
   // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
@@ -236,7 +252,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor? channel_scales,"
       " Tensor? token_scales,"
       " str? schedule"
-      ") -> Tensor");
+      ") -> Tensor",
+      {stride_tag});
   ops.def(
       "machete_prepack_B("
       " Tensor B,"
@@ -255,7 +272,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
       "int b_q_type, "
       "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
-      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
+      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
+      {stride_tag});
   // conditionally compiled so impl registration is in source file
 
   // gptq_marlin repack from GPTQ.
@@ -291,7 +309,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
-      "SymInt size_k) -> Tensor");
+      "SymInt size_k) -> Tensor",
+      {stride_tag});
   // conditionally compiled so impl registration is in source file
 
   // marlin_qqq_gemm for QQQ.
@@ -299,14 +318,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
       "Tensor s_tok, Tensor s_ch, Tensor s_group, "
       "Tensor! workspace, SymInt size_m, SymInt size_n, "
-      "SymInt size_k) -> Tensor");
+      "SymInt size_k) -> Tensor",
+      {stride_tag});
   // conditionally compiled so impl registration is in source file
 
   // CUTLASS nvfp4 block scaled GEMM
   ops.def(
       "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
       " Tensor block_scale_a, Tensor block_scale_b,"
-      " Tensor alpha) -> ()");
+      " Tensor alpha) -> ()",
+      {stride_tag});
   ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
 
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
@@ -314,7 +335,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "cutlass_scaled_mm(Tensor! out, Tensor a,"
       " Tensor b, Tensor a_scales,"
-      " Tensor b_scales, Tensor? bias) -> ()");
+      " Tensor b_scales, Tensor? bias) -> ()",
+      {stride_tag});
   ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
 
   // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
@@ -323,7 +345,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
       " Tensor b, Tensor a_scales,"
       " Tensor b_scales, Tensor azp_adj,"
-      " Tensor? azp, Tensor? bias) -> ()");
+      " Tensor? azp, Tensor? bias) -> ()",
+      {stride_tag});
   ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
 
   // Check if cutlass scaled_mm is supported for CUDA devices of the given
@@ -351,7 +374,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
       " Tensor bt_nzs,"
       " Tensor bt_meta, Tensor a_scales,"
-      " Tensor b_scales, Tensor? bias) -> ()");
+      " Tensor b_scales, Tensor? bias) -> ()",
+      {stride_tag});
   ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);
 
   // CUTLASS sparse matrix compressor
@@ -407,7 +431,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
       "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
-      "-> Tensor");
+      "-> Tensor",
+      {stride_tag});
   ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
 
   // Post processing for GPTQ.