[Kernel] fix moe_align_block_size error condition (#12239)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Jinzhen Lin 2025-01-22 02:30:28 +08:00 committed by GitHub
parent 9705b90bcf
commit 1e60f87bb3


@@ -233,15 +233,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
       (num_experts + 1) * sizeof(int32_t);

   bool use_global_memory = false;
-  bool use_i16 = false;  // Use uint16_t for shared memory token counts
-  if (shared_mem_i16 > device_max_shared_mem) {
-    use_global_memory = true;
-  } else if (shared_mem_i32 > device_max_shared_mem &&
+  bool use_i16 = false;  // Use uint16_t for shared memory token counts
+  if (shared_mem_i32 < device_max_shared_mem) {
+    // Do nothing in this case. We're all set to use int32_t token counts
+  } else if (shared_mem_i16 < device_max_shared_mem &&
              topk_ids.numel() <= 65535) {
     // when nelements of topk_ids is smaller than 65535 (max value of uint16),
     // element value of token_cnts would also smaller than 65535,
     // so we can use uint16 as dtype of token_cnts
     use_i16 = true;
+  } else {
+    use_global_memory = true;
   }

   if (use_global_memory) {
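
For reference, the rewritten chain reads as a three-way choice: prefer int32_t token counts in shared memory whenever that layout fits, fall back to uint16_t counts only when that smaller layout fits and topk_ids.numel() <= 65535 (so every per-expert count stays within uint16_t), and otherwise spill the counts to global memory. Under the old ordering, the case where the int32_t layout did not fit but topk_ids.numel() exceeded 65535 matched neither branch and still took the shared-memory int32_t path, which could not fit. The snippet below is a minimal standalone sketch of the corrected decision order only, not the kernel launch code; the enum, helper function, and example sizes are invented for illustration.

#include <cstdint>
#include <cstdio>

// Hypothetical labels for the three code paths selected in moe_align_block_size.
enum class TokenCntPath { kSharedI32, kSharedI16, kGlobalMemory };

// Mirrors the fixed ordering: shared-memory int32_t counts first, then
// shared-memory uint16_t counts (only if every count is provably < 65536),
// and global memory as the last resort.
static TokenCntPath choose_path(int64_t shared_mem_i32, int64_t shared_mem_i16,
                                int64_t device_max_shared_mem,
                                int64_t topk_numel) {
  if (shared_mem_i32 < device_max_shared_mem) {
    return TokenCntPath::kSharedI32;
  } else if (shared_mem_i16 < device_max_shared_mem && topk_numel <= 65535) {
    return TokenCntPath::kSharedI16;
  }
  return TokenCntPath::kGlobalMemory;
}

int main() {
  // Example numbers only (not taken from any real device or model):
  // 164 KiB of opt-in shared memory, int32_t counts needing 200 KiB,
  // uint16_t counts needing 110 KiB, and 100000 topk entries.
  const int64_t device_max_shared_mem = 164 * 1024;
  const int64_t shared_mem_i32 = 200 * 1024;
  const int64_t shared_mem_i16 = 110 * 1024;
  const int64_t topk_numel = 100000;

  // With these inputs the uint16_t layout fits, but topk_numel > 65535, so
  // the corrected logic picks global memory; the old logic matched neither
  // branch and stayed on the (too large) shared-memory int32_t path.
  TokenCntPath path = choose_path(shared_mem_i32, shared_mem_i16,
                                  device_max_shared_mem, topk_numel);
  std::printf("chosen path: %d\n", static_cast<int>(path));
  return 0;
}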