// 2024-08-02 16:51:58 -04:00
#include "core/registration.h"

// 2024-06-09 16:23:30 -04:00
#include "moe_ops.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Schema for the top-k softmax over the gating outputs.
  // The three `Tensor!` arguments are declared as mutated by the op (outputs
  // written in place). `gating_output` is a plain, read-only `Tensor`.
  // The op returns nothing (`-> ()`).
  ops.def(
      "topk_softmax("
      "Tensor! topk_weights, "
      "Tensor! topk_indices, "
      "Tensor! token_expert_indices, "
      "Tensor gating_output"
      ") -> ()");

  // Bind the implementation for the CUDA dispatch key only.
  // `topk_softmax` is presumably declared in moe_ops.h — confirm.
  ops.impl("topk_softmax", torch::kCUDA, &topk_softmax);
}
// Extension-level registration boilerplate for TORCH_EXTENSION_NAME.
// NOTE(review): the macro is presumably defined in core/registration.h
// (included above) — confirm what it expands to.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)