Commit af684f9

match tests for the MOE layers with main.
vllmellm committed Jan 15, 2025
1 parent 7c05f3e commit af684f9
Showing 2 changed files with 6 additions and 88 deletions.
93 changes: 6 additions & 87 deletions tests/kernels/test_moe.py
@@ -29,19 +29,6 @@
TOP_KS = [2, 6]


-def permute_weight(x: torch.Tensor) -> torch.Tensor:
-    ## Hardcode BLOCK_K and BLOCK_N
-    BK = 128
-    BN = 128
-    x_ = x.clone()
-    x_ = x_.view(x.shape[0], x.shape[1] // BN, BN // 16, 16, x.shape[2] // BK,
-                 BK // 32, 4, 8)
-    x_ = x_.permute(0, 1, 5, 2, 6, 4, 3, 7)
-    x_ = x_.contiguous()
-    x_ = x_.view(x.shape[0], x.shape[1], x.shape[2])
-    return x_


@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
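
Note: the deleted helper was a pure re-layout of the expert weights into the 128-wide blocked tiling apparently expected by the (also deleted) VLLM_MOE_SHUFFLE path below. A quick sanity check of that reading, reusing the permute_weight definition shown above, with arbitrary shapes divisible by the hardcoded tiles:

import torch

# 4 experts; 256 and 512 are both divisible by the hardcoded BN = BK = 128.
w = torch.randn(4, 256, 512)
w_shuffled = permute_weight(w)

# The shuffle only reorders elements: same shape, same multiset of values.
assert w_shuffled.shape == w.shape
assert torch.equal(w_shuffled.flatten().sort().values,
                   w.flatten().sort().values)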
@@ -65,9 +52,9 @@ def test_fused_moe(

    # Pad the input if use padding
    if envs.VLLM_MOE_PADDING:
-        w1 = F.pad(w1, (0, 128), "constant", 0)[..., :-128]
+        w1 = F.pad(w1, (0, 128), "constant", 0)
        torch.cuda.empty_cache()
-        w2 = F.pad(w2, (0, 128), "constant", 0)[..., :-128]
+        w2 = F.pad(w2, (0, 128), "constant", 0)
        torch.cuda.empty_cache()
    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
    torch_output = torch_moe(a, w1, w2, score, topk)
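
To see what the old and new forms above produce, here is a small illustration of the pad-then-slice idiom (shapes arbitrary): slicing the padded tensor back restores the logical shape but keeps the padded row pitch in memory, while the new unsliced form keeps the padded shape itself.

import torch
import torch.nn.functional as F

w = torch.randn(8, 256, 512)

padded = F.pad(w, (0, 128), "constant", 0)  # appends 128 zeros to the last dim
print(padded.shape)             # torch.Size([8, 256, 640])

trimmed = padded[..., :-128]    # a view: original shape, padded strides
print(trimmed.shape)            # torch.Size([8, 256, 512])
print(trimmed.stride())         # (163840, 640, 1): rows stay 640 elements apart
print(trimmed.is_contiguous())  # False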
@@ -79,74 +66,6 @@
                          rtol=0)


-@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237])
-@pytest.mark.parametrize("n", [14336])
-@pytest.mark.parametrize("k", [4096])
-@pytest.mark.parametrize("e", [8])
-@pytest.mark.parametrize("topk", [2])
-@pytest.mark.parametrize("dtype", [torch.float16])
-def test_amd_moe_1(
-    m: int,
-    n: int,
-    k: int,
-    e: int,
-    topk: int,
-    dtype: torch.dtype,
-):
-    if n == k:
-        pytest.skip()
-    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
-    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
-    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
-    if envs.VLLM_MOE_SHUFFLE:
-        w1_shuffled = permute_weight(w1.data)
-        w2_shuffled = permute_weight(w2.data)
-
-    score = torch.randn((m, e), device='cuda', dtype=dtype)
-    triton_output = fused_moe(a,
-                              w1_shuffled,
-                              w2_shuffled,
-                              score,
-                              topk,
-                              renormalize=False)
-    torch_output = torch_moe(a, w1, w2, score, topk)
-    assert torch.allclose(triton_output, torch_output, atol=2e-2, rtol=0)
-
-
-@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237])
-@pytest.mark.parametrize("n", [4096])
-@pytest.mark.parametrize("k", [14336])
-@pytest.mark.parametrize("e", [8])
-@pytest.mark.parametrize("topk", [2])
-@pytest.mark.parametrize("dtype", [torch.float16])
-def test_amd_moe_2(
-    m: int,
-    n: int,
-    k: int,
-    e: int,
-    topk: int,
-    dtype: torch.dtype,
-):
-    if n == k:
-        pytest.skip()
-    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
-    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
-    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
-    if envs.VLLM_MOE_SHUFFLE:
-        w1_shuffled = permute_weight(w1.data)
-        w2_shuffled = permute_weight(w2.data)
-
-    score = torch.randn((m, e), device='cuda', dtype=dtype)
-    triton_output = fused_moe(a,
-                              w1_shuffled,
-                              w2_shuffled,
-                              score,
-                              topk,
-                              renormalize=False)
-    torch_output = torch_moe(a, w1, w2, score, topk)
-    assert torch.allclose(triton_output, torch_output, atol=2e-1, rtol=0)


@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode()
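
Incidentally, the two deleted tests defined w1_shuffled/w2_shuffled only inside the VLLM_MOE_SHUFFLE guard but used them unconditionally, so with the flag at its default of False (see the envs.py hunk below) they would die with a NameError before asserting anything. A stripped-down illustration:

VLLM_MOE_SHUFFLE = False  # the deleted flag's default
if VLLM_MOE_SHUFFLE:
    w1_shuffled = "..."
triton_input = w1_shuffled  # NameError: name 'w1_shuffled' is not defined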
@@ -181,13 +100,13 @@ def test_mixtral_moe(dtype: torch.dtype):

    # pad the weight if using padding
    if envs.VLLM_MOE_PADDING:
-        vllm_moe.experts.w13_weight = Parameter(
-            F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant", 0),
-            requires_grad=False)[..., :-128]
+        vllm_moe.experts.w13_weight = Parameter(F.pad(
+            vllm_moe.experts.w13_weight, (0, 128), "constant", 0),
+                                                requires_grad=False)
        torch.cuda.empty_cache()
        vllm_moe.experts.w2_weight = Parameter(F.pad(
            vllm_moe.experts.w2_weight, (0, 128), "constant", 0),
-                                               requires_grad=False)[..., :-128]
+                                               requires_grad=False)
        torch.cuda.empty_cache()

    # Run forward passes for both MoE blocks
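
A likely motivation for dropping the trailing slice in this test: indexing an nn.Parameter returns a plain Tensor, and nn.Module refuses to rebind a registered parameter attribute to anything but a Parameter (or None). A minimal sketch of that behavior, with toy shapes:

import torch
import torch.nn.functional as F
from torch import nn

layer = nn.Linear(4, 4)

with torch.no_grad():
    padded = F.pad(layer.weight, (0, 128), "constant", 0)

# Indexing a Parameter yields a plain Tensor, not a Parameter.
print(type(nn.Parameter(padded, requires_grad=False)[..., :-128]))

try:
    # Rebinding a registered parameter to a plain Tensor raises TypeError.
    layer.weight = nn.Parameter(padded, requires_grad=False)[..., :-128]
except TypeError as err:
    print(err)  # roughly: cannot assign 'torch.FloatTensor' as parameter 'weight'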
1 change: 0 additions & 1 deletion vllm/envs.py
@@ -82,7 +82,6 @@
    VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS: int = 1
    VLLM_MOE_PADDING: bool = True
    VLLM_FP8_PADDING: bool = True
-    VLLM_MOE_SHUFFLE: bool = False
    FUSED_MOE_PERSISTENT: bool = False
    VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
    VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
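
For readers unfamiliar with vllm/envs.py: the typed names above sit under an if TYPE_CHECKING: block and are resolved lazily at attribute access, which is why retiring a flag shows up as a one-line deletion here. A rough sketch of that pattern (the entries and defaults are illustrative, not this fork's exact table):

import os
from typing import Any, Callable, Dict

# Each flag has a lazy parser; reading envs.VLLM_MOE_PADDING triggers it.
environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_MOE_PADDING": lambda: bool(int(os.getenv("VLLM_MOE_PADDING", "1"))),
    "VLLM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_FP8_PADDING", "1"))),
}

def __getattr__(name: str) -> Any:
    # Module-level __getattr__ (PEP 562) resolves flags on first access.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")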