
// Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
// Co-authored-by: mgoin <mgoin64@gmail.com>
// (file viewer metadata: 31 lines, 1.4 KiB, C++)
#pragma once
#include <torch/all.h>
// Top-k softmax over MoE gating logits (declaration only; the definition
// lives in another translation unit).
//
// NOTE(review): the semantics below are inferred from parameter names and
// const-ness only — confirm against the implementation.
//
// @param topk_weights          Non-const reference; presumably receives the
//                              selected experts' routing weights per token.
// @param topk_indices          Non-const reference; presumably receives the
//                              selected expert ids per token.
// @param token_expert_indices  Non-const reference; presumably receives a
//                              token -> expert-slot mapping.
// @param gating_output         Presumably the router logits the softmax is
//                              computed over (passed by non-const reference
//                              in the original signature; kept unchanged so
//                              the declaration still matches the definition).
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
                  torch::Tensor& token_expert_indices,
                  torch::Tensor& gating_output);
// Reduction helper for MoE outputs (declaration only; defined elsewhere).
//
// NOTE(review): judging by the name, this sums `input` (presumably across
// the top-k expert dimension) into `output` — confirm the reduced axis
// against the implementation.
//
// @param input   Source tensor, taken by non-const reference.
// @param output  Destination tensor, taken by non-const reference;
//                presumably written in place.
void moe_sum(torch::Tensor& input,
             torch::Tensor& output);
// Aligns per-token expert assignments to `block_size` (declaration only;
// defined elsewhere).
//
// NOTE(review): the semantics below are inferred from names — confirm against
// the implementation. Tensor arguments are taken by value; torch::Tensor is a
// handle type, so the callee can presumably still write into caller-provided
// storage (sorted_token_ids / experts_ids / num_tokens_post_pad look like
// output buffers).
//
// @param topk_ids             Presumably the expert ids chosen for each token.
// @param num_experts          Number of experts.
// @param block_size           Alignment / padding granularity.
// @param sorted_token_ids     Presumably receives token ids grouped by expert.
// @param experts_ids          Presumably receives the expert id per block.
// @param num_tokens_post_pad  Presumably receives the padded token count.
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);
// Alternative entry point with the same parameter list as
// moe_align_block_size (declaration only; defined elsewhere).
//
// NOTE(review): the "sgl" prefix suggests an implementation ported from
// another project (possibly SGLang) — confirm. Parameter meanings are
// presumably identical to moe_align_block_size; verify against the
// implementation. Tensor arguments are by-value handles, so the callee can
// presumably write into caller-provided storage.
//
// @param topk_ids             Presumably the expert ids chosen for each token.
// @param num_experts          Number of experts.
// @param block_size           Alignment / padding granularity.
// @param sorted_token_ids     Presumably an output buffer.
// @param experts_ids          Presumably an output buffer.
// @param num_tokens_post_pad  Presumably an output buffer.
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                              int64_t block_size,
                              torch::Tensor sorted_token_ids,
                              torch::Tensor experts_ids,
                              torch::Tensor num_tokens_post_pad);
// Grouped GEMM for MoE with quantized weights (declaration only; defined
// elsewhere).
//
// NOTE(review): the semantics below are inferred from names — confirm against
// the implementation. "wna16" presumably means N-bit weights with 16-bit
// activations.
//
// @param input              Activation tensor.
// @param output             Presumably the destination buffer; the function
//                           also returns a torch::Tensor (possibly `output`
//                           itself — confirm).
// @param b_qweight          Quantized weight tensor.
// @param b_scales           Weight dequantization scales.
// @param b_qzeros           Optional quantized zero points; may be absent.
// @param topk_weights       Optional per-token routing weights; may be absent.
// @param sorted_token_ids   Presumably produced by moe_align_block_size.
// @param expert_ids         Presumably produced by moe_align_block_size.
// @param num_tokens_post_pad  Presumably produced by moe_align_block_size.
// @param top_k              Number of experts selected per token.
// @param BLOCK_SIZE_M       Tiling parameter (presumably the M tile size).
// @param BLOCK_SIZE_N       Tiling parameter (presumably the N tile size).
// @param BLOCK_SIZE_K       Tiling parameter (presumably the K tile size).
// @param bit                Presumably the weight quantization bit width.
// @return A torch::Tensor; see the note on `output` above.
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                             torch::Tensor b_qweight, torch::Tensor b_scales,
                             std::optional<torch::Tensor> b_qzeros,
                             std::optional<torch::Tensor> topk_weights,
                             torch::Tensor sorted_token_ids,
                             torch::Tensor expert_ids,
                             torch::Tensor num_tokens_post_pad, int64_t top_k,
                             int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                             int64_t BLOCK_SIZE_K, int64_t bit);