[Kernel] fix moe_align_block_size error condition (#12239)
Signed-off-by: Jinzhen Lin <[email protected]>
jinzhen-lin authored Jan 21, 2025
1 parent 9705b90 commit 1e60f87
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions csrc/moe/moe_align_sum_kernels.cu
@@ -233,15 +233,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
       (num_experts + 1) * sizeof(int32_t);
 
   bool use_global_memory = false;
-  bool use_i16 = false; // Use uint16_t for shared memory token counts
-  if (shared_mem_i16 > device_max_shared_mem) {
-    use_global_memory = true;
-  } else if (shared_mem_i32 > device_max_shared_mem &&
+  bool use_i16 = false; // Use uint16_t for shared memory token counts
+  if (shared_mem_i32 < device_max_shared_mem) {
+    // Do nothing in this case. We're all set to use int32_t token counts
+  } else if (shared_mem_i16 < device_max_shared_mem &&
              topk_ids.numel() <= 65535) {
     // when nelements of topk_ids is smaller than 65535 (max value of uint16),
     // element value of token_cnts would also smaller than 65535,
     // so we can use uint16 as dtype of token_cnts
     use_i16 = true;
+  } else {
+    use_global_memory = true;
   }
 
   if (use_global_memory) {
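
The change reorders the storage selection for the token-count buffers. Previously, the code only fell back to global memory when even the uint16_t shared-memory layout was too large, so a case whose int32_t layout exceeded shared memory but whose topk_ids.numel() was above 65535 would still take the int32_t shared-memory path and fail. The fixed logic prefers int32_t counts in shared memory when they fit, uses uint16_t counts only when that smaller layout fits and topk_ids.numel() is within uint16_t range, and otherwise falls back to global memory. The following standalone C++ sketch mirrors that decision; the CountStorage enum, the pick_count_storage helper, and the example sizes are illustrative assumptions, not identifiers from the vLLM source.

// Standalone sketch of the corrected buffer selection for token counts.
// CountStorage, pick_count_storage, and the example sizes below are
// illustrative assumptions, not identifiers from the vLLM source tree.
#include <cstddef>
#include <cstdint>
#include <cstdio>

enum class CountStorage { kSharedI32, kSharedI16, kGlobal };

CountStorage pick_count_storage(size_t shared_mem_i32, size_t shared_mem_i16,
                                size_t device_max_shared_mem,
                                int64_t topk_numel) {
  if (shared_mem_i32 < device_max_shared_mem) {
    // int32_t token counts fit in shared memory: preferred path, nothing
    // else to check.
    return CountStorage::kSharedI32;
  } else if (shared_mem_i16 < device_max_shared_mem && topk_numel <= 65535) {
    // Each per-expert count is bounded by topk_ids.numel(), so uint16_t
    // counters cannot overflow when numel() <= 65535 and the smaller
    // layout fits in shared memory.
    return CountStorage::kSharedI16;
  } else {
    // Neither shared-memory layout is usable; fall back to global memory.
    return CountStorage::kGlobal;
  }
}

int main() {
  const size_t max_smem = 96 * 1024;  // e.g. 96 KiB of shared memory per block

  // int32_t layout fits: stay on the int32_t shared-memory path.
  printf("%d\n", static_cast<int>(
                     pick_count_storage(64 * 1024, 32 * 1024, max_smem, 4096)));
  // int32_t layout too large, uint16_t layout fits and numel() is small.
  printf("%d\n", static_cast<int>(pick_count_storage(128 * 1024, 64 * 1024,
                                                     max_smem, 4096)));
  // int32_t layout too large and numel() exceeds uint16_t range: the old
  // condition would have kept shared memory here; the fix goes global.
  printf("%d\n", static_cast<int>(pick_count_storage(128 * 1024, 64 * 1024,
                                                     max_smem, 100000)));
  return 0;
}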
