2024-06-09 16:23:30 -04:00
|
|
|
#include "cache.h"
|
|
|
|
#include "cuda_utils.h"
|
|
|
|
#include "ops.h"
|
2024-08-02 16:51:58 -04:00
|
|
|
#include "core/registration.h"
|
2024-06-09 16:23:30 -04:00
|
|
|
|
|
|
|
#include <torch/library.h>
|
2025-03-06 17:17:09 -05:00
|
|
|
#include <torch/version.h>
|
2024-06-09 16:23:30 -04:00
|
|
|
|
|
|
|
// Note on op signatures:
|
|
|
|
// The X_meta signatures are for the meta functions corresponding to op X.
|
|
|
|
// They must be kept in sync with the signature for X. Generally, only
|
|
|
|
// functions that return Tensors require a meta function.
|
|
|
|
//
|
|
|
|
// See the following links for detailed docs on op registration and function
|
|
|
|
// schemas.
|
|
|
|
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
|
|
|
|
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
|
|
|
|
|
|
|
|
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|
|
|
// vLLM custom ops
|
2025-03-06 17:17:09 -05:00
|
|
|
//
|
|
|
|
|
|
|
|
// The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
|
|
|
|
// to override this for many GEMMs with the following tag. Otherwise,
|
|
|
|
// torch.compile will force all input tensors to be contiguous(), which
|
|
|
|
// will break many custom ops that require column-major weight matrices.
|
|
|
|
// TODO: remove this for PyTorch 2.8, when the default is planned to switch
|
|
|
|
// to match exact eager-mode strides.
|
|
|
|
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
|
2024-06-09 16:23:30 -04:00
|
|
|
|
2024-10-27 00:19:28 -07:00
|
|
|
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
|
|
|
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
|
|
|
|
|
2024-06-09 16:23:30 -04:00
|
|
|
// Attention ops
|
|
|
|
// Compute the attention between an input query and the cached
|
|
|
|
// keys/values using PagedAttention.
|
|
|
|
ops.def(
|
|
|
|
"paged_attention_v1("
|
|
|
|
" Tensor! out, Tensor query, Tensor key_cache,"
|
|
|
|
" Tensor value_cache, int num_kv_heads, float scale,"
|
|
|
|
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
|
|
|
" int max_seq_len, Tensor? alibi_slopes,"
|
2025-01-23 13:04:03 -05:00
|
|
|
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
2024-07-16 18:31:32 -04:00
|
|
|
" int tp_rank, int blocksparse_local_blocks,"
|
2024-06-09 16:23:30 -04:00
|
|
|
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
|
|
|
" int blocksparse_head_sliding_step) -> ()");
|
|
|
|
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
|
|
|
|
|
|
|
|
// PagedAttention V2.
|
|
|
|
ops.def(
|
|
|
|
"paged_attention_v2("
|
2024-09-11 15:52:19 -04:00
|
|
|
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
|
|
|
|
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
|
2024-06-09 16:23:30 -04:00
|
|
|
" Tensor value_cache, int num_kv_heads, float scale,"
|
|
|
|
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
|
|
|
" int max_seq_len, Tensor? alibi_slopes,"
|
2025-01-23 13:04:03 -05:00
|
|
|
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
2024-07-16 18:31:32 -04:00
|
|
|
" int tp_rank, int blocksparse_local_blocks,"
|
2024-06-09 16:23:30 -04:00
|
|
|
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
|
|
|
" int blocksparse_head_sliding_step) -> ()");
|
|
|
|
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
|
|
|
|
|
|
|
|
// Activation ops
|
|
|
|
// Activation function used in SwiGLU.
|
|
|
|
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
|
|
|
|
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
|
|
|
|
2025-01-15 10:29:53 +08:00
|
|
|
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
|
|
|
|
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
|
|
|
|
|
2024-06-09 16:23:30 -04:00
|
|
|
// Activation function used in GeGLU with `none` approximation.
|
|
|
|
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
|
|
|
|
ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
|
|
|
|
|
|
|
|
// Activation function used in GeGLU with `tanh` approximation.
|
|
|
|
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
|
|
|
|
ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
|
|
|
|
|
2024-10-24 16:18:27 +08:00
|
|
|
// FATReLU implementation.
|
|
|
|
ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
|
|
|
|
ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);
|
|
|
|
|
2024-06-09 16:23:30 -04:00
|
|
|
// GELU implementation used in GPT-2.
|
|
|
|
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
|
|
|
|
ops.impl("gelu_new", torch::kCUDA, &gelu_new);
|
|
|
|
|
|
|
|
// Approximate GELU implementation.
|
|
|
|
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
|
|
|
|
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
|
|
|
|
|
2024-06-20 04:52:09 -07:00
|
|
|
// Quick GELU implementation.
|
|
|
|
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
|
|
|
|
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
|
|
|
|
|
2024-07-17 17:30:28 -04:00
|
|
|
// prepare_inputs advance_step
|
2024-09-11 15:52:19 -04:00
|
|
|
ops.def(
|
2024-09-12 11:16:22 -07:00
|
|
|
"advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
|
2024-09-11 15:52:19 -04:00
|
|
|
"Tensor! input_tokens, Tensor sampled_token_ids, "
|
|
|
|
"Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
|
|
|
|
"Tensor block_tables) -> ()");
|
2024-09-12 11:16:22 -07:00
|
|
|
ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn);
|
|
|
|
|
|
|
|
ops.def(
|
|
|
|
"advance_step_flashinfer("
|
|
|
|
" int num_seqs, int num_queries, int block_size,"
|
|
|
|
" Tensor! input_tokens, Tensor sampled_token_ids,"
|
|
|
|
" Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
|
|
|
|
" Tensor block_tables, Tensor! paged_kv_indices,"
|
|
|
|
" Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
|
|
|
|
" Tensor! block_table_bounds"
|
|
|
|
") -> ()");
|
|
|
|
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
|
2024-07-17 17:30:28 -04:00
|
|
|
|
2024-06-09 16:23:30 -04:00
|
|
|
// Layernorm
|
|
|
|
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
|
|
|
ops.def(
|
2024-11-08 16:20:08 -05:00
|
|
|
"rms_norm(Tensor! result, Tensor input, Tensor weight, float epsilon) -> "
|
2024-06-09 16:23:30 -04:00
|
|
|
"()");
|
|
|
|
ops.impl("rms_norm", torch::kCUDA, &rms_norm);
|
|
|
|
|
|
|
|
// In-place fused Add and RMS Normalization.
|
|
|
|
ops.def(
|
|
|
|
"fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
|
|
|
|
"float epsilon) -> ()");
|
|
|
|
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
|
|
|
|
|
2024-11-08 16:20:08 -05:00
|
|
|
// Layernorm-quant
|
|
|
|
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
|
|
|
ops.def(
|
|
|
|
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
|
|
|
|
"Tensor scale, float epsilon) -> "
|
|
|
|
"()");
|
|
|
|
ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
|
|
|
|
&rms_norm_static_fp8_quant);
|
|
|
|
|
|
|
|
// In-place fused Add and RMS Normalization.
|
|
|
|
ops.def(
|
|
|
|
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
|
|
|
|
"Tensor! residual, Tensor weight, "
|
|
|
|
"Tensor scale, float epsilon) -> ()");
|
|
|
|
ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
|
|
|
|
&fused_add_rms_norm_static_fp8_quant);
|
|
|
|
|
2024-12-12 22:19:23 -05:00
|
|
|
// Fused Layernorm + Quant kernels
|
|
|
|
ops.def(
|
|
|
|
"rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, "
|
|
|
|
"Tensor weight, Tensor! scale, float epsilon, "
|
|
|
|
"Tensor? scale_ub, Tensor!? residual) -> ()");
|
|
|
|
ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
|
|
|
|
&rms_norm_dynamic_per_token_quant);
|
|
|
|
|
2024-06-09 16:23:30 -04:00
|
|
|
// Rotary embedding
|
|
|
|
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
|
|
|
|
ops.def(
|
|
|
|
"rotary_embedding(Tensor positions, Tensor! query,"
|
|
|
|
" Tensor! key, int head_size,"
|
|
|
|
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
|
|
|
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
|
|
|
|
|
|
|
|
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
|
|
|
|
// (supports multiple loras).
|
|
|
|
ops.def(
|
|
|
|
"batched_rotary_embedding(Tensor positions, Tensor! query,"
|
|
|
|
" Tensor! key, int head_size,"
|
|
|
|
" Tensor cos_sin_cache, bool is_neox,"
|
|
|
|
" int rot_dim,"
|
|
|
|
" Tensor cos_sin_cache_offsets) -> ()");
|
|
|
|
ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
|
|
|
|
|
|
|
|
// Quantization ops
|
|
|
|
#ifndef USE_ROCM
|
|
|
|
// Quantized GEMM for AQLM.
|
2024-09-11 15:52:19 -04:00
|
|
|
ops.def(
|
|
|
|
"aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
|
|
|
|
"Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
|
2025-03-06 17:17:09 -05:00
|
|
|
"-> Tensor",
|
|
|
|
{stride_tag});
|
2024-06-09 16:23:30 -04:00
|
|
|
ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
|
|
|
|
|
|
|
|
// Decompression method for AQLM.
|
2024-09-11 15:52:19 -04:00
|
|
|
ops.def(
|
|
|
|
"aqlm_dequant(Tensor codes, Tensor codebooks, "
|
2025-03-06 17:17:09 -05:00
|
|
|
"int[] codebook_partition_sizes) -> Tensor",
|
|
|
|
{stride_tag});
|
2024-06-09 16:23:30 -04:00
|
|
|
ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
|
|
|
|
|
|
|
|
// Quantized GEMM for AWQ.
|
2024-09-11 15:52:19 -04:00
|
|
|
ops.def(
|
|
|
|
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
|
2025-03-06 17:17:09 -05:00
|
|
|
"Tensor _zeros, SymInt split_k_iters) -> Tensor",
|
|
|
|
{stride_tag});
|
2024-06-09 16:23:30 -04:00
|
|
|
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
|
|
|
|
|
|
|
|
// Dequantization for AWQ.
|
2024-09-11 15:52:19 -04:00
|
|
|
ops.def(
|
|
|
|
"awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
|
2025-03-06 17:17:09 -05:00
|
|
|
"Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor",
|
|
|
|
{stride_tag});
|
2024-06-09 16:23:30 -04:00
|
|
|
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
|
|
|
|
|
2024-09-11 15:52:19 -04:00
|
|
|
// Note about marlin kernel 'workspace' arguments:
|
|
|
|
// Technically these should be mutable since they are modified by the kernel.
|
|
|
|
// But since they are set back to zero once the kernel is finished we can
|
|
|
|
// hand wave and say that they have no net effect.
|
|
|
|
//
|
|
|
|
// The reason to mark 'workspace' as immutable is so that they don't interfere
|
|
|
|
// with using ScalarType arguments in the ops. If they are marked as mutable,
|
|
|
|
// pytorch throws an assert in
|
|
|
|
// 'torch._higher_order_ops._register_effectful_op' that prevents these
|
|
|
|
// kernels from being torch.compile'd.
|
|
|
|
// See the following document for more info on custom types and ops that use
|
|
|
|
// custom types:
|
|
|
|
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
|
|
|
|
|
2024-06-09 16:23:30 -04:00
|
|
|
// Marlin (Dense) Optimized Quantized GEMM for GPTQ.
|
2024-09-11 15:52:19 -04:00
|
|
|
ops.def(
|
|
|
|
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
2024-10-17 15:08:34 -04:00
|
|
|
"Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
|
2025-03-06 17:17:09 -05:00
|
|
|
"Tensor",
|
|
|
|
{stride_tag});
|
2024-10-03 22:55:25 -04:00
|
|
|
// conditionally compiled so impl in source file
|
2024-06-09 16:23:30 -04:00
|
|
|
|
|
|
|
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
|
2024-09-11 15:52:19 -04:00
|
|
|
ops.def(
|
|
|
|
"gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
|
|
|
|
"Tensor b_scales, Tensor workspace, "
|
2024-10-17 15:08:34 -04:00
|
|
|
"int b_q_type, "
|
2025-03-06 17:17:09 -05:00
|
|
|
"SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor",
|
|
|
|
{stride_tag});
|
2024-10-03 22:55:25 -04:00
|
|
|
// conditionally compiled so impl in source file
|
2024-06-09 16:23:30 -04:00
|
|
|
|
2024-08-20 09:09:33 -04:00
|
|
|
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
|
2024-10-03 22:55:25 -04:00
|
|
|
ops.def(
|
2024-11-18 14:59:29 -05:00
|
|
|
"machete_supported_schedules("
|
|
|
|
" ScalarType a_type,"
|
|
|
|
" int b_type,"
|
|
|
|
" ScalarType? maybe_group_scales_type,"
|
|
|
|
" ScalarType? maybe_group_zeros_type,"
|
|
|
|
" ScalarType? maybe_channel_scales_type,"
|
|
|
|
" ScalarType? maybe_token_scales_type,"
|
|
|
|
" ScalarType? maybe_out_type"
|
|
|
|
") -> str[]");
|
|
|
|
ops.def(
|
|
|
|
"machete_mm("
|
|
|
|
" Tensor A,"
|
|
|
|
" Tensor B,"
|
|
|
|
" int b_type,"
|
|
|
|
" ScalarType? out_type,"
|
|
|
|
" Tensor? group_scales,"
|
|
|
|
" Tensor? group_zeros,"
|
|
|
|
" int? group_size,"
|
|
|
|
" Tensor? channel_scales,"
|
|
|
|
" Tensor? token_scales,"
|
|
|
|
" str? schedule"
|
2025-03-06 17:17:09 -05:00
|
|
|
") -> Tensor",
|
|
|
|
{stride_tag});
|
2024-11-18 14:59:29 -05:00
|
|
|
ops.def(
|
|
|
|
"machete_prepack_B("
|
|
|
|
" Tensor B,"
|
|
|
|
" ScalarType a_type,"
|
|
|
|
" int b_type,"
|
|
|
|
" ScalarType? group_scales_type"
|
|
|
|
") -> Tensor");
|
2024-10-03 22:55:25 -04:00
|
|
|
// conditionally compiled so impl registration is in source file
|
2024-08-20 09:09:33 -04:00
|
|
|
|
2024-09-23 13:46:26 -04:00
|
|
|
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
|
|
|
ops.impl("permute_cols", torch::kCUDA, &permute_cols);
|
|
|
|
|
2024-06-09 16:23:30 -04:00
|
|
|
// gptq_marlin Optimized Quantized GEMM for GPTQ.
|
2024-09-11 15:52:19 -04:00
|
|
|
ops.def(
|
|
|
|
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
|
|
|
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
|
2024-10-17 15:08:34 -04:00
|
|
|
"int b_q_type, "
|
|
|
|
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
|
2025-03-08 00:53:38 +08:00
|
|
|
"bool has_zp, bool use_atomic_add, bool use_fp32_reduce, "
|
|
|
|
"bool is_zp_float) -> Tensor",
|
2025-03-06 17:17:09 -05:00
|
|
|
{stride_tag});
|
2024-10-03 22:55:25 -04:00
|
|
|
// conditionally compiled so impl registration is in source file
|
2024-06-09 16:23:30 -04:00
|
|
|
|
|
|
|
// gptq_marlin repack from GPTQ.
|
2024-09-11 15:52:19 -04:00
|
|
|
ops.def(
|
|
|
|
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
|
|
|
|
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
|
2024-10-03 22:55:25 -04:00
|
|
|
// conditionally compiled so impl registrations are in source file
|
2024-06-09 16:23:30 -04:00
|
|
|
|
2024-07-21 19:41:42 -04:00
|
|
|
// awq_marlin repack from AWQ.
|
2024-09-11 15:52:19 -04:00
|
|
|
ops.def(
|
|
|
|
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
|
|
|
|
"SymInt size_n, int num_bits) -> Tensor");
|
2024-10-03 22:55:25 -04:00
|
|
|
// conditionally compiled so impl registrations are in source file
|
2024-11-23 13:14:49 +08:00
|
|
|
#endif
|
2024-07-21 19:41:42 -04:00
|
|
|
|
2024-08-06 07:54:23 +08:00
|
|
|
// Dequantization for GGML.
|
2024-10-17 15:08:34 -04:00
|
|
|
ops.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor");
|
2024-08-06 07:54:23 +08:00
|
|
|
ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
|
|
|
|
|
|
|
|
// mmvq kernel for GGML.
|
2024-09-11 15:52:19 -04:00
|
|
|
ops.def(
|
2024-10-17 15:08:34 -04:00
|
|
|
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
|
2024-09-11 15:52:19 -04:00
|
|
|
"-> Tensor");
|
2024-08-06 07:54:23 +08:00
|
|
|
ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
|
|
|
|
|
|
|
|
// mmq kernel for GGML.
|
2024-10-17 15:08:34 -04:00
|
|
|
ops.def(
|
|
|
|
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
|
2024-08-06 07:54:23 +08:00
|
|
|
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
|
|
|
|
|
2025-03-12 04:33:27 +01:00
|
|
|
// moe kernel for GGML.
|
|
|
|
ops.def(
|
|
|
|
"ggml_moe_a8(Tensor X, Tensor W, "
|
|
|
|
"Tensor sorted_token_ids, Tensor expert_ids, Tensor "
|
|
|
|
"num_tokens_post_padded, "
|
|
|
|
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
|
|
|
|
ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
|
|
|
|
|
|
|
|
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
|
|
|
|
|
2024-11-23 13:14:49 +08:00
|
|
|
#ifndef USE_ROCM
|
2024-07-03 13:38:00 -04:00
|
|
|
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
|
2024-09-11 15:52:19 -04:00
|
|
|
ops.def(
|
|
|
|
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
2024-10-17 15:08:34 -04:00
|
|
|
"Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
|
2025-03-06 17:17:09 -05:00
|
|
|
"SymInt size_k) -> Tensor",
|
|
|
|
{stride_tag});
|
2024-10-03 22:55:25 -04:00
|
|
|
// conditionally compiled so impl registration is in source file
|
2024-07-03 13:38:00 -04:00
|
|
|
|
2024-07-31 21:55:21 +08:00
|
|
|
// marlin_qqq_gemm for QQQ.
|
2024-09-11 15:52:19 -04:00
|
|
|
ops.def(
|
|
|
|
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
|
|
|
|
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
|
2024-10-17 15:08:34 -04:00
|
|
|
"Tensor! workspace, SymInt size_m, SymInt size_n, "
|
2025-03-06 17:17:09 -05:00
|
|
|
"SymInt size_k) -> Tensor",
|
|
|
|
{stride_tag});
|
2024-10-03 22:55:25 -04:00
|
|
|
// conditionally compiled so impl registration is in source file
|
2024-07-31 21:55:21 +08:00
|
|
|
|
2025-02-22 05:24:05 -08:00
|
|
|
// CUTLASS nvfp4 block scaled GEMM
|
|
|
|
ops.def(
|
|
|
|
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
|
|
|
|
" Tensor block_scale_a, Tensor block_scale_b,"
|
2025-03-06 17:17:09 -05:00
|
|
|
" Tensor alpha) -> ()",
|
|
|
|
{stride_tag});
|
2025-02-22 05:24:05 -08:00
|
|
|
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
|
|
|
|
|
2024-06-09 16:23:30 -04:00
|
|
|
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
2024-08-06 14:17:08 -04:00
|
|
|
// quantization, as well as bias
|
2024-06-09 16:23:30 -04:00
|
|
|
ops.def(
|
2024-06-13 14:22:19 -04:00
|
|
|
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
|
|
|
" Tensor b, Tensor a_scales,"
|
2025-03-06 17:17:09 -05:00
|
|
|
" Tensor b_scales, Tensor? bias) -> ()",
|
|
|
|
{stride_tag});
|
2024-06-13 14:22:19 -04:00
|
|
|
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
|
2024-06-20 14:36:10 -04:00
|
|
|
|
2024-08-06 14:17:08 -04:00
|
|
|
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
|
|
|
|
// quantization.
|
|
|
|
ops.def(
|
|
|
|
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
|
|
|
|
" Tensor b, Tensor a_scales,"
|
|
|
|
" Tensor b_scales, Tensor azp_adj,"
|
2025-03-06 17:17:09 -05:00
|
|
|
" Tensor? azp, Tensor? bias) -> ()",
|
|
|
|
{stride_tag});
|
2024-08-06 14:17:08 -04:00
|
|
|
ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
|
|
|
|
|
2024-06-20 14:36:10 -04:00
|
|
|
// Check if cutlass scaled_mm is supported for CUDA devices of the given
|
|
|
|
// capability
|
2024-09-11 15:52:19 -04:00
|
|
|
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
|
|
|
|
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
|
|
|
|
|
2025-01-31 18:29:11 -05:00
|
|
|
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
|
|
|
|
ops.def(
|
|
|
|
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
|
|
|
|
"bool");
|
|
|
|
ops.impl("cutlass_scaled_mm_supports_block_fp8",
|
|
|
|
&cutlass_scaled_mm_supports_fp8);
|
|
|
|
|
2024-12-18 21:43:30 -05:00
|
|
|
// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
|
|
|
|
// given capability
|
|
|
|
ops.def(
|
|
|
|
"cutlass_sparse_scaled_mm_supported(int cuda_device_capability) -> bool");
|
|
|
|
ops.impl("cutlass_sparse_scaled_mm_supported",
|
|
|
|
&cutlass_sparse_scaled_mm_supported);
|
|
|
|
|
2024-12-18 09:57:16 -05:00
|
|
|
// CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
|
|
|
|
// quantization, as well as bias
|
|
|
|
ops.def(
|
|
|
|
"cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
|
|
|
|
" Tensor bt_nzs,"
|
|
|
|
" Tensor bt_meta, Tensor a_scales,"
|
2025-03-06 17:17:09 -05:00
|
|
|
" Tensor b_scales, Tensor? bias) -> ()",
|
|
|
|
{stride_tag});
|
2024-12-18 09:57:16 -05:00
|
|
|
ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);
|
|
|
|
|
|
|
|
// CUTLASS sparse matrix compressor
|
2025-02-13 19:01:14 -05:00
|
|
|
ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]");
|
|
|
|
ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);
|
2024-12-18 09:57:16 -05:00
|
|
|
|
2024-08-29 01:06:52 +03:00
|
|
|
// Mamba selective scan kernel
|
|
|
|
ops.def(
|
|
|
|
"selective_scan_fwd(Tensor! u, Tensor! delta,"
|
|
|
|
"Tensor! A, Tensor! B, Tensor! C,"
|
2024-09-30 00:35:58 +03:00
|
|
|
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
|
2024-08-29 01:06:52 +03:00
|
|
|
"bool delta_softplus,"
|
2024-09-30 00:35:58 +03:00
|
|
|
"Tensor? query_start_loc,"
|
|
|
|
"Tensor? cache_indices,"
|
|
|
|
"Tensor? has_initial_state,"
|
2024-10-17 00:12:43 +08:00
|
|
|
"Tensor! ssm_states,"
|
|
|
|
"int pad_slot_id) -> ()");
|
2024-08-29 01:06:52 +03:00
|
|
|
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
|
|
|
|
|
|
|
ops.def(
|
|
|
|
"causal_conv1d_update(Tensor! x,"
|
|
|
|
"Tensor! conv_state,"
|
|
|
|
"Tensor! weight,"
|
2024-09-30 00:35:58 +03:00
|
|
|
"Tensor? bias_,"
|
2024-09-17 19:44:27 -04:00
|
|
|
"bool silu_activation,"
|
2024-09-30 00:35:58 +03:00
|
|
|
"Tensor? cache_seqlens_,"
|
2024-10-17 00:12:43 +08:00
|
|
|
"Tensor? conv_state_indices,"
|
|
|
|
"int pad_slot_id) -> ()");
|
2024-08-29 01:06:52 +03:00
|
|
|
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
|
|
|
|
|
|
|
|
ops.def(
|
|
|
|
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
|
|
|
|
"Tensor? bias_,"
|
2024-09-30 00:35:58 +03:00
|
|
|
"Tensor!? conv_states,"
|
|
|
|
"Tensor? query_start_loc,"
|
|
|
|
"Tensor? cache_indices,"
|
|
|
|
"Tensor? has_initial_state,"
|
2024-10-17 00:12:43 +08:00
|
|
|
"bool silu_activation,"
|
|
|
|
"int pad_slot_id) -> ()");
|
2024-08-29 01:06:52 +03:00
|
|
|
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
|
2025-02-14 20:30:42 -08:00
|
|
|
|
|
|
|
// Compute NVFP4 block quantized tensor.
|
|
|
|
ops.def(
|
|
|
|
"scaled_fp4_quant(Tensor! output, Tensor input,"
|
|
|
|
" Tensor! output_scale, Tensor input_scale) -> ()");
|
|
|
|
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
|
|
|
|
|
2024-06-09 16:23:30 -04:00
|
|
|
#endif
|
|
|
|
|
|
|
|
// Quantized GEMM for GPTQ.
|
2024-09-11 15:52:19 -04:00
|
|
|
// Note: even though the C++ inferred schema is correct for this op, it seems
|
|
|
|
// to prevent the meta function registry.
|
|
|
|
ops.def(
|
|
|
|
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
|
|
|
|
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
|
2025-03-06 17:17:09 -05:00
|
|
|
"-> Tensor",
|
|
|
|
{stride_tag});
|
2024-06-09 16:23:30 -04:00
|
|
|
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
|
|
|
|
|
|
|
|
// Post processing for GPTQ.
|
|
|
|
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
|
|
|
|
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
|
|
|
|
|
|
|
|
// Compute FP8 quantized tensor for given scaling factor.
|
|
|
|
ops.def(
|
2024-11-08 16:20:08 -05:00
|
|
|
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
|
|
|
|
"()");
|
2024-06-09 16:23:30 -04:00
|
|
|
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
|
|
|
|
|
2024-07-17 21:38:35 -04:00
|
|
|
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
|
2024-06-09 16:23:30 -04:00
|
|
|
ops.def(
|
2024-11-08 16:20:08 -05:00
|
|
|
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
|
|
|
|
"-> "
|
2024-06-09 16:23:30 -04:00
|
|
|
"()");
|
|
|
|
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
|
|
|
|
|
2024-07-17 21:38:35 -04:00
|
|
|
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
|
|
|
|
ops.def(
|
2024-11-08 16:20:08 -05:00
|
|
|
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
|
2024-09-11 15:52:19 -04:00
|
|
|
"Tensor! scale, Tensor? scale_ub) -> "
|
2024-07-17 21:38:35 -04:00
|
|
|
"()");
|
|
|
|
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
|
|
|
|
&dynamic_per_token_scaled_fp8_quant);
|
|
|
|
|
2024-06-09 16:23:30 -04:00
|
|
|
// Compute int8 quantized tensor for given scaling factor.
|
|
|
|
ops.def(
|
2024-11-08 16:20:08 -05:00
|
|
|
"static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
|
2024-09-16 14:52:40 -04:00
|
|
|
"Tensor? azp) -> ()");
|
2024-06-09 16:23:30 -04:00
|
|
|
ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
|
|
|
|
|
|
|
|
// Compute int8 quantized tensor and scaling factor
|
|
|
|
ops.def(
|
2024-11-08 16:20:08 -05:00
|
|
|
"dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, "
|
2024-09-16 14:52:40 -04:00
|
|
|
"Tensor!? azp) -> ()");
|
2024-06-09 16:23:30 -04:00
|
|
|
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
|
|
|
|
&dynamic_scaled_int8_quant);
|
2025-03-01 14:30:59 +08:00
|
|
|
|
|
|
|
#ifndef USE_ROCM
|
|
|
|
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
|
|
|
ops.def(
|
|
|
|
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
|
|
|
|
"Tensor? b_zeros, "
|
|
|
|
"bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, "
|
|
|
|
"Tensor!? b_zeros_reorder, "
|
|
|
|
"int K, int N, int N_32align) -> ()");
|
|
|
|
// conditionally compiled so impl in source file
|
|
|
|
|
|
|
|
// AllSpark quantization ops
|
|
|
|
ops.def(
|
|
|
|
"allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, "
|
|
|
|
"Tensor? b_qzeros, "
|
|
|
|
"SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt "
|
|
|
|
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor");
|
|
|
|
// conditionally compiled so impl in source file
|
|
|
|
#endif
|
2024-06-09 16:23:30 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  // Cache ops
  // Every op here is declared with an explicit schema and bound to its CUDA
  // implementation.

  // Swap in (out) the cache blocks from src to dst.
  cache_ops.def(
      "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);

  // Copy the cache blocks from src to dst.
  // NOTE(review): key_caches annotates its elements as mutable
  // (Tensor(a!)[]) while value_caches annotates the list itself
  // (Tensor[](b!)); kept unchanged here -- confirm which form is intended.
  cache_ops.def(
      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
      "Tensor block_mapping) -> ()");
  // The function pointer must be &copy_blocks; a mis-decoded "&copy" HTML
  // entity had turned it into an invalid token.
  cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);

  cache_ops.def(
      "copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()");
  cache_ops.impl("copy_blocks_mla", torch::kCUDA, &copy_blocks_mla);

  // Reshape the key and value tensors and cache them.
  cache_ops.def(
      "reshape_and_cache(Tensor key, Tensor value,"
      " Tensor! key_cache, Tensor! value_cache,"
      " Tensor slot_mapping,"
      " str kv_cache_dtype,"
      " Tensor k_scale, Tensor v_scale) -> ()");
  cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);

  // Reshape the key and value tensors and cache them.
  cache_ops.def(
      "reshape_and_cache_flash(Tensor key, Tensor value,"
      " Tensor! key_cache,"
      " Tensor! value_cache,"
      " Tensor slot_mapping,"
      " str kv_cache_dtype,"
      " Tensor k_scale, Tensor v_scale) -> ()");
  cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
                 &reshape_and_cache_flash);

  // Concat kv_c and k_pe and cache them.
  cache_ops.def(
      "concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
      " Tensor! kv_cache,"
      " Tensor slot_mapping,"
      " str kv_cache_dtype,"
      " Tensor scale) -> ()");
  cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);

  // Convert the key and value cache to fp8 data type.
  cache_ops.def(
      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
      "str kv_cache_dtype) -> ()");
  cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);

  // Gather cache blocks from src_cache to dst.
  cache_ops.def(
      "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
      "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
  cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache);
}
|
|
|
|
|
|
|
|
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
  // CUDA device-query helpers. Both ops are declared with an explicit schema
  // and bound via impl() without a dispatch key.

  // Look up an arbitrary device attribute for the requested device id.
  cuda_utils.def(
      "get_device_attribute(int attribute, int device_id) -> int");
  cuda_utils.impl("get_device_attribute", &get_device_attribute);

  // Look up the maximum-shared-memory-per-block attribute of a device.
  cuda_utils.def(
      "get_max_shared_memory_per_block_device_attribute("
      "int device_id) -> int");
  cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
                  &get_max_shared_memory_per_block_device_attribute);
}
|
|
|
|
|
|
|
|
#ifndef USE_ROCM
|
|
|
|
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
|
|
|
// Custom all-reduce kernels
|
2024-09-11 15:52:19 -04:00
|
|
|
custom_ar.def(
|
2024-11-06 23:50:47 -08:00
|
|
|
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
|
|
|
|
"int rank, bool full_nvlink) -> int");
|
2024-06-09 16:23:30 -04:00
|
|
|
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
|
|
|
|
custom_ar.def(
|
2024-11-06 23:50:47 -08:00
|
|
|
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
|
|
|
|
"int reg_buffer_sz_bytes) -> ()");
|
|
|
|
custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);
|
2024-06-09 16:23:30 -04:00
|
|
|
|
|
|
|
custom_ar.def("dispose", &dispose);
|
|
|
|
custom_ar.def("meta_size", &meta_size);
|
|
|
|
|
2024-11-06 23:50:47 -08:00
|
|
|
custom_ar.def("register_buffer", ®ister_buffer);
|
2024-06-09 16:23:30 -04:00
|
|
|
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
|
|
|
custom_ar.def("register_graph_buffers", ®ister_graph_buffers);
|
|
|
|
}
|
|
|
|
#endif
|
|
|
|
|
|
|
|
// Extension-level registration for TORCH_EXTENSION_NAME via the project's
// REGISTER_EXTENSION macro (presumably defined in core/registration.h).
// NOTE(review): confirm what the macro emits, e.g. a Python module init hook.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|