[Kernel] fix moe_align_block_size error condition (#12239)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
parent 9705b90bcf
commit 1e60f87bb3
@@ -233,15 +233,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
       (num_experts + 1) * sizeof(int32_t);
 
   bool use_global_memory = false;
-  bool use_i16 = false;  // Use uint16_t for shared memory token counts
-  if (shared_mem_i16 > device_max_shared_mem) {
-    use_global_memory = true;
-  } else if (shared_mem_i32 > device_max_shared_mem &&
+  bool use_i16 = false;  // Use uint16_t for shared memory token counts
+  if (shared_mem_i32 < device_max_shared_mem) {
+    // Do nothing in this case. We're all set to use int32_t token counts
+  } else if (shared_mem_i16 < device_max_shared_mem &&
              topk_ids.numel() <= 65535) {
     // when nelements of topk_ids is smaller than 65535 (max value of uint16),
     // element value of token_cnts would also smaller than 65535,
     // so we can use uint16 as dtype of token_cnts
     use_i16 = true;
+  } else {
+    use_global_memory = true;
+  }
 
   if (use_global_memory) {