From 1e60f87bb37bc28410e6cf6e9030e9a28ad49d12 Mon Sep 17 00:00:00 2001
From: Jinzhen Lin
Date: Wed, 22 Jan 2025 02:30:28 +0800
Subject: [PATCH] [Kernel] fix moe_align_block_size error condition (#12239)

Signed-off-by: Jinzhen Lin
---
 csrc/moe/moe_align_sum_kernels.cu | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu
index 715a1b42841f2..d609ce1697df3 100644
--- a/csrc/moe/moe_align_sum_kernels.cu
+++ b/csrc/moe/moe_align_sum_kernels.cu
@@ -233,15 +233,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           (num_experts + 1) * sizeof(int32_t);
 
   bool use_global_memory = false;
-  bool use_i16 = false; // Use uint16_t for shared memory token counts
-  if (shared_mem_i16 > device_max_shared_mem) {
-    use_global_memory = true;
-  } else if (shared_mem_i32 > device_max_shared_mem &&
+  bool use_i16 = false;  // Use uint16_t for shared memory token counts
+  if (shared_mem_i32 < device_max_shared_mem) {
+    // Do nothing in this case. We're all set to use int32_t token counts
+  } else if (shared_mem_i16 < device_max_shared_mem &&
              topk_ids.numel() <= 65535) {
     // when nelements of topk_ids is smaller than 65535 (max value of uint16),
     // element value of token_cnts would also smaller than 65535,
     // so we can use uint16 as dtype of token_cnts
     use_i16 = true;
+  } else {
+    use_global_memory = true;
   }
 
   if (use_global_memory) {
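
For context, the hunk above reorders the shared-memory checks so that int32 token counts are preferred, uint16 counts are the first fallback, and global memory is used only when neither buffer fits. The standalone sketch below restates that corrected selection order outside the kernel launcher; the function pick_token_cnt_strategy, the enum, and the plain-integer parameters are illustrative stand-ins for the values computed inside moe_align_block_size, not code from the patch.

#include <cstdint>
#include <cstdio>

// Illustrative strategies mirroring the use_i16 / use_global_memory flags
// in the patched launcher.
enum class TokenCntStrategy { SharedMemI32, SharedMemI16, GlobalMemory };

// Sketch of the corrected selection order: prefer int32 shared-memory
// counts, fall back to uint16 counts when they fit and their values cannot
// overflow uint16, otherwise fall back to global memory.
TokenCntStrategy pick_token_cnt_strategy(int64_t shared_mem_i32,
                                         int64_t shared_mem_i16,
                                         int64_t device_max_shared_mem,
                                         int64_t topk_numel) {
  if (shared_mem_i32 < device_max_shared_mem) {
    return TokenCntStrategy::SharedMemI32;
  } else if (shared_mem_i16 < device_max_shared_mem && topk_numel <= 65535) {
    // With at most 65535 topk ids, every per-expert count fits in uint16_t.
    return TokenCntStrategy::SharedMemI16;
  }
  return TokenCntStrategy::GlobalMemory;
}

int main() {
  // Example: int32 counts exceed 48 KiB of shared memory, uint16 counts fit.
  auto s = pick_token_cnt_strategy(/*shared_mem_i32=*/64 * 1024,
                                   /*shared_mem_i16=*/40 * 1024,
                                   /*device_max_shared_mem=*/48 * 1024,
                                   /*topk_numel=*/4096);
  std::printf("use uint16 counts: %d\n", s == TokenCntStrategy::SharedMemI16);
  return 0;
}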