2024-02-05 17:38:02 -08:00
|
|
|
#pragma once
|
|
|
|
|
2024-06-09 16:23:30 -04:00
|
|
|
#include <torch/all.h>
|
2024-02-05 17:38:02 -08:00
|
|
|
|
2024-05-22 03:18:41 -04:00
|
|
|
// Declaration only; the definition lives in a separate translation unit.
//
// NOTE(review): from its name this is presumably a top-k selection over a
// softmax of `gating_output`; the header itself does not establish that.
// Roles below are inferred from parameter names — confirm against the
// implementation:
//   topk_weights          presumably receives the selected routing weights.
//   topk_indices          presumably receives the selected expert ids.
//   token_expert_indices  presumably receives a per-token expert-slot mapping.
//   gating_output         presumably the router logits consumed as input.
// All four tensors are taken by non-const reference. The exact signature must
// stay in sync with the out-of-line definition.
void topk_softmax(torch::Tensor& topk_weights,
                  torch::Tensor& topk_indices,
                  torch::Tensor& token_expert_indices,
                  torch::Tensor& gating_output);
|
2024-10-24 17:37:52 -05:00
|
|
|
|
|
|
|
// Declaration only; the definition lives in a separate translation unit.
//
// NOTE(review): from the parameter names, `input` is presumably read and
// `output` written (e.g. a reduction over per-expert results) — confirm.
// Both tensors are non-const references even though `input` looks read-only.
// The signature is kept verbatim because it must match the out-of-line
// definition exactly.
void moe_sum(torch::Tensor& input,
             torch::Tensor& output);
|
|
|
|
|
|
|
|
// Declaration only; the definition lives in a separate translation unit.
//
// NOTE(review): behavior is inferred from names and not established here —
// presumably groups token ids by expert and pads each group to a multiple of
// `block_size`. Confirm against the implementation.
//   topk_ids             presumably the input per-token expert assignments.
//   num_experts          number of experts (int64_t).
//   block_size           alignment granularity (int64_t).
//   sorted_token_ids     presumably written: token ids grouped by expert.
//   experts_ids          presumably written: expert id for each block.
//   num_tokens_post_pad  presumably written: token count after padding.
// Tensors are passed by value. torch::Tensor is a handle type, so writes made
// through these parameters presumably still reach the caller's storage.
// The signature must stay in sync with the definition.
void moe_align_block_size(torch::Tensor topk_ids,
                          int64_t num_experts,
                          int64_t block_size,
                          torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);
|
2025-02-02 21:09:50 -08:00
|
|
|
|
|
|
|
// Declaration only; the definition lives in a separate translation unit.
//
// Takes exactly the same parameter list as moe_align_block_size, so it is
// presumably an alternative implementation of the same operation.
// NOTE(review): the `sgl_` prefix presumably refers to an SGLang-derived
// kernel — confirm. Parameter roles (inferred from names, unverified):
//   topk_ids             presumably the input per-token expert assignments.
//   num_experts          number of experts (int64_t).
//   block_size           alignment granularity (int64_t).
//   sorted_token_ids     presumably written: token ids grouped by expert.
//   experts_ids          presumably written: expert id for each block.
//   num_tokens_post_pad  presumably written: token count after padding.
// The signature must stay in sync with the definition.
void sgl_moe_align_block_size(torch::Tensor topk_ids,
                              int64_t num_experts,
                              int64_t block_size,
                              torch::Tensor sorted_token_ids,
                              torch::Tensor experts_ids,
                              torch::Tensor num_tokens_post_pad);
|