diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index 41aa6ec1601ee..4d9de33ff6a0d 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -116,15 +116,18 @@ def test_fused_moe_quant_int(m: int, n: int, k: int, e: int, topk: int,
         weight = weight.T
         qweight = qweight.T.contiguous().to(torch.uint8)
         scales = scales.T
-        qzeros = qzeros.T.contiguous().to(torch.uint8)
+        if has_zp:
+            qzeros = qzeros.T.contiguous().to(torch.uint8)
         if weight_bits == 4:
             qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
-            qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]
+            if has_zp:
+                qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]
         w_ref[expert_id] = weight
         w_qweight[expert_id] = qweight
         w_scales[expert_id] = scales
-        w_qzeros[expert_id] = qzeros
+        if has_zp:
+            w_qzeros[expert_id] = qzeros
 
     triton_output = fused_moe(a,
                               w1_qweight,
@@ -136,8 +139,8 @@ def test_fused_moe_quant_int(m: int, n: int, k: int, e: int, topk: int,
                               use_int8_w8a16=weight_bits == 8,
                               w1_scale=w1_scales,
                               w2_scale=w2_scales,
-                              w1_zp=w1_qzeros,
-                              w2_zp=w2_qzeros,
+                              w1_zp=w1_qzeros if has_zp else None,
+                              w2_zp=w2_qzeros if has_zp else None,
                               block_shape=[0, group_size])
     torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 4b36262c0ca53..eda3ff96bc87b 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -457,8 +457,7 @@ def moe_align_block_size(
                                       dtype=torch.int32,
                                       device=topk_ids.device)
     ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
-                             expert_ids, num_tokens_post_pad,
-                             num_experts >= 256)
+                             expert_ids, num_tokens_post_pad)
     return sorted_ids, expert_ids, num_tokens_post_pad