[Kernel] fix moe_align_block_size error condition (#12239)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Author: Jinzhen Lin, 2025-01-22 02:30:28 +08:00 (committed by GitHub)
Parent: 9705b90bcf
Commit: 1e60f87bb3


@@ -233,15 +233,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
       (num_experts + 1) * sizeof(int32_t);
   bool use_global_memory = false;
   bool use_i16 = false;  // Use uint16_t for shared memory token counts
-  if (shared_mem_i16 > device_max_shared_mem) {
-    use_global_memory = true;
-  } else if (shared_mem_i32 > device_max_shared_mem &&
+  if (shared_mem_i32 < device_max_shared_mem) {
+    // Do nothing in this case. We're all set to use int32_t token counts
+  } else if (shared_mem_i16 < device_max_shared_mem &&
              topk_ids.numel() <= 65535) {
     // when nelements of topk_ids is smaller than 65535 (max value of uint16),
     // element value of token_cnts would also smaller than 65535,
     // so we can use uint16 as dtype of token_cnts
     use_i16 = true;
+  } else {
+    use_global_memory = true;
   }
   if (use_global_memory) {
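
Net effect of the change: the old cascade only fell back to global memory when even the uint16_t layout exceeded the shared-memory limit, so the case where the int32_t layout does not fit and topk_ids.numel() > 65535 fell through neither branch and ran with an oversized int32_t shared-memory buffer. The rewritten cascade checks the preferred int32_t layout first, then the uint16_t layout, and sends every remaining case to global memory. Below is a minimal standalone sketch of that corrected selection logic; pick_counts_storage and CountsStorage are hypothetical names introduced for illustration and are not part of the vLLM source, while shared_mem_i32, shared_mem_i16, and device_max_shared_mem mirror the variables in the diff above.

    // Hypothetical sketch of the corrected selection logic, not the vLLM source.
    #include <cstddef>
    #include <cstdint>

    enum class CountsStorage { SharedI32, SharedI16, GlobalMemory };

    CountsStorage pick_counts_storage(size_t shared_mem_i32, size_t shared_mem_i16,
                                      int64_t num_topk_ids,
                                      size_t device_max_shared_mem) {
      if (shared_mem_i32 < device_max_shared_mem) {
        // Preferred path: int32_t token counts fit in shared memory.
        return CountsStorage::SharedI32;
      } else if (shared_mem_i16 < device_max_shared_mem && num_topk_ids <= 65535) {
        // Each token count is bounded by numel(topk_ids), so uint16_t cannot overflow.
        return CountsStorage::SharedI16;
      } else {
        // All remaining cases, including numel > 65535 with an oversized int32_t
        // layout (the case the old condition missed), spill to global memory.
        return CountsStorage::GlobalMemory;
      }
    }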