#pragma once #include #include #include "core/scalar_type.hpp" #include torch::Tensor weak_ref_tensor(torch::Tensor& tensor) { // Ensure tensor is on CUDA if (!tensor.is_cuda()) { throw std::runtime_error("Tensor must be on CUDA device"); } // Get the raw data pointer void* data_ptr = tensor.data_ptr(); // Get tensor sizes and strides std::vector sizes = tensor.sizes().vec(); std::vector strides = tensor.strides().vec(); // Get tensor options (dtype, device) auto options = tensor.options(); // Create a new tensor from the raw data pointer auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options); return new_tensor; } void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); void paged_attention_v2( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon); void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, torch::Tensor& scale, double epsilon); void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, torch::Tensor& scale, double epsilon); void rms_norm_dynamic_per_token_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& weight, torch::Tensor& scales, double const epsilon, std::optional scale_ub, std::optional residual); void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox, int64_t rot_dim, torch::Tensor& cos_sin_cache_offsets); void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input, double threshold); void gelu_new(torch::Tensor& out, torch::Tensor& input); void gelu_fast(torch::Tensor& out, torch::Tensor& input); void gelu_quick(torch::Tensor& out, torch::Tensor& input); void advance_step_flashattn(int64_t num_seqs, int64_t num_queries, int64_t block_size, torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, torch::Tensor& input_positions, torch::Tensor& seq_lens, torch::Tensor& slot_mapping, torch::Tensor& block_tables); void advance_step_flashinfer( int64_t num_seqs, int64_t num_queries, int64_t block_size, torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, torch::Tensor& input_positions, torch::Tensor& seq_lens, torch::Tensor& slot_mapping, torch::Tensor& block_tables, torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds); #ifndef USE_ROCM torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, const torch::Tensor& codebooks, const torch::Tensor& scales, const std::vector& codebook_partition_sizes, const std::optional& bias); torch::Tensor aqlm_dequant( const torch::Tensor& codes, const torch::Tensor& codebooks, const std::vector& codebook_partition_sizes); torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, int64_t split_k_iters); torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, int64_t split_k_iters, int64_t thx, int64_t thy); torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); #endif torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n); torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row); torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row); #ifndef USE_ROCM bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, c10::optional const& bias); void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& azp_adj, c10::optional const& azp, c10::optional const& bias); void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& e, torch::Tensor const& a_scales, torch::Tensor const& b_scales, c10::optional const& bias); bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed, torch::Tensor& e, torch::Tensor const& a); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale, c10::optional const& azp); void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales, c10::optional const& azp); torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, bool use_exllama, int64_t bit); void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale); void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale); void dynamic_per_token_scaled_fp8_quant( torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, c10::optional const& scale_ub); void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& C, const c10::optional& D_, const c10::optional& z_, const c10::optional& delta_bias_, bool delta_softplus, const c10::optional& query_start_loc, const c10::optional& cache_indices, const c10::optional& has_initial_state, const torch::Tensor& ssm_states, int64_t pad_slot_id); void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight, const c10::optional& bias_, bool silu_activation, const c10::optional& cache_seqlens_, const c10::optional& conv_state_indices_, int64_t pad_slot_id); void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, const c10::optional& bias_, const c10::optional& conv_states, const c10::optional& query_start_loc, const c10::optional& cache_indices, const c10::optional& has_initial_state, bool silu_activation, int64_t pad_slot_id); #ifndef USE_ROCM using fptr_t = int64_t; fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink); void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t reg_buffer, int64_t reg_buffer_sz_bytes); void dispose(fptr_t _fa); int64_t meta_size(); void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs); std::tuple, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); void register_graph_buffers(fptr_t _fa, const std::vector>& handles, const std::vector>& offsets); #endif