diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 5791dd832d089..e743cfa193ade 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1014,23 +1014,21 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk #pragma unroll for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { if (c_gl_wr < c_gl_wr_end) { - if (c_gl_wr / c_gl_stride < prob_m) { - int row = sh_sorted[c_gl_wr / c_gl_stride]; - int off = row * c_gl_stride + c_gl_wr % c_gl_stride; - __half* ctrg = reinterpret_cast<__half*>(&C[off]); - // HERE we read from sh, how is the access different? - __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); - // printf("c offset: %d at row %d from %d (%d %d)\n", off, row, c_gl_wr / c_gl_stride, threadIdx.x, blockIdx.x); - for (int j = 0; j < 8; ++j) { - // printf("csrc %f\n", __half2float(csrc[j])); - // printf("ctrg %f\n", __half2float(ctrg[j])); - // printf("csrc %f, ctrg %f\n", __half2float(csrc[j]), __half2float(ctrg[j])); - __half old = ctrg[j]; - ctrg[j] = __float2half(__half2float(old) + __half2float(csrc[j])); - } - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; + int row = sh_sorted[c_gl_wr / c_gl_stride]; + int off = row * c_gl_stride + c_gl_wr % c_gl_stride; + __half* ctrg = reinterpret_cast<__half*>(&C[off]); + // HERE we read from sh, how is the access different? 
+ __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); + // printf("c offset: %d at row %d from %d (%d %d)\n", off, row, c_gl_wr / c_gl_stride, threadIdx.x, blockIdx.x); + for (int j = 0; j < 8; ++j) { + // printf("csrc %f\n", __half2float(csrc[j])); + // printf("ctrg %f\n", __half2float(ctrg[j])); + // printf("csrc %f, ctrg %f\n", __half2float(csrc[j]), __half2float(ctrg[j])); + __half old = ctrg[j]; + ctrg[j] = __float2half(__half2float(old) + __half2float(csrc[j])); } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; } } }; @@ -1395,10 +1393,10 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, void* sorted_ids for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { // printf("init ptrs for expert %d and gs %d\n", expert_idx, group_size); const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B;// + (prob_n * prob_k / 32) * expert_idx; + const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; int4* C_ptr = (int4*)C; int* sorted_ids_ptr = (int*)sorted_ids;// + moe_block_size * expert_idx; - const int4* s_ptr = (const int4*)s;// + (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * prob_n / 8) * expert_idx; + const int4* s_ptr = (const int4*)s + (((group_size == -1 || group_size == 0) ? 
1 : prob_k / group_size) * prob_n / 8) * expert_idx; // const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; // const int* perm_ptr = (const int*)perm; int4* red_tmp_ptr = (int4*)red_tmp; diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 0f7c6be3f9cbd..b23d6b2797d42 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -6,14 +6,14 @@ import pytest import torch import numpy +from typing import List from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_moe, fused_marlin_moe from vllm.model_executor.models.mixtral import MixtralMoE -""" def torch_moe(a, w1, w2, score, topk): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) @@ -24,13 +24,13 @@ def torch_moe(a, w1, w2, score, topk): topk_ids = topk_ids.view(-1) for i in range(w1.shape[0]): mask = topk_ids == i + # print(mask) if mask.sum(): out[mask] = SiluAndMul()( a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - @pytest.mark.parametrize("m", [512, 222, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024]) @@ -59,8 +59,8 @@ def test_fused_moe( [torch.float32, torch.float16, torch.bfloat16]) @torch.inference_mode() def test_mixtral_moe(dtype: torch.dtype): - "Make sure our Mixtral MoE implementation agrees with the one from - huggingface." + "Make sure our Mixtral MoE implementation agrees with the one from" + "huggingface." 
# Instantiate our and huggingface's MoE blocks config = MixtralConfig() @@ -102,8 +102,6 @@ def test_mixtral_moe(dtype: torch.dtype): rtol=mixtral_moe_tol[dtype], atol=mixtral_moe_tol[dtype]) -""" - def get_marlin_perms(): perm = [] for i in range(32): @@ -162,6 +160,8 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm): q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=numpy.uint32) + print("PACKED:", q_w.shape, ">>", q_packed.shape) + for i in range(pack_factor): q_packed |= q_w[:, i::pack_factor] << num_bits * i @@ -236,6 +236,7 @@ def marlin_quantize( num_bits: int, group_size: int, ): + print("START:", w.size(), num_bits, group_size) perm, scale_perm, scale_perm_single = get_marlin_perms() print("SHAPE:", w.shape) @@ -249,6 +250,7 @@ def marlin_quantize( # Quantize w_ref, q_w, s = quantize_weights(w, num_bits, group_size) + print("interm:", w_ref.size(), q_w.size(), s.size()) #TODO experts # Reformat to marlin @@ -268,35 +270,79 @@ def marlin_quantize( import vllm._moe_C as moe_kernels +def stack_and_dev(tensors: List[torch.Tensor]): + dev = tensors[0].device + return torch.stack(tensors, dim=0).to(dev) + def test_fused_marlin_moe(): - m = 16 - n = 128 - k = 128 - e = 1 + m = 256 + n = 512 + k = 512 + e = 8 + topk = 2 dtype = torch.float16 moe_block_size = 16 + e_m = 8 a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 - w1 = torch.randn((2 * n, k), device='cuda', dtype=dtype) / 10 - w2 = torch.randn((k, n), device='cuda', dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + w_m = torch.randn((e_m, k, n), device='cuda', dtype=dtype) / 10 a_2d = a.view(-1, a.shape[-1]) - w_ref, qweight, scales = marlin_quantize(w1, 4, -1) - qweight = qweight.unsqueeze(0) - scales = scales.unsqueeze(0) + w_refs = [] + qweights = [] + scaless = [] + + for i in range(w_m.shape[0]): + w_ref, qweight, scales = marlin_quantize(w_m[i], 4, -1) + 
w_refs.append(w_ref) + qweights.append(qweight) + scaless.append(scales) + + w_ref = stack_and_dev(w_refs) + qweight = stack_and_dev(qweights) + scales = stack_and_dev(scaless) + print("w_ref size:", w_ref.size()) + print("qweight size:", qweight.size()) print("scales size:", scales.size()) # Allocate marlin workspace max_workspace_size = (n // 64) * 16 workspace = torch.zeros(max_workspace_size, dtype=torch.int, + device="cuda", requires_grad=False) + shuffles = torch.arange(m, dtype=torch.int) sorted_ids = torch.full([m + (moe_block_size - 1)], m, dtype=torch.int).cuda() + sorted_ids[:m] = shuffles - # score = torch.randn((m, e), device='cuda', dtype=dtype) - moe_kernels.marlin_gemm_moe(a_2d, qweight, sorted_ids, scales, workspace, m, n, k) - # triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) - # torch_output = torch_moe(a, w1, w2, score, topk) - # assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) \ No newline at end of file + score = torch.randn((m, e), device='cuda', dtype=dtype) + marlin_output = moe_kernels.marlin_gemm_moe(a, qweight, sorted_ids, scales, workspace, m, n, k) + torch_output = torch_moe(a, w1, w2, score, topk) + # assert torch.allclose(marlin_output, torch_output, atol=1e-2, rtol=0) + +@pytest.mark.parametrize("m", [512]) #, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048]) #, 256, 1024]) +@pytest.mark.parametrize("k", [128]) #, 511, 1024]) +@pytest.mark.parametrize("e", [8]) #, 64]) +@pytest.mark.parametrize("topk", [2]) #, 6]) +@pytest.mark.parametrize("dtype", [torch.float16]) #, torch.bfloat16]) +def test_fused_marlin_moe_2( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + + score = torch.randn((m, e), device='cuda', dtype=dtype) + triton_output = fused_marlin_moe(a, w1, w2, 
score, topk, renormalize=False) + torch_output = torch_moe(a, w1, w2, score, topk) + assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 496d69c89c62b..1bd1b2876d838 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,7 +1,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_moe, get_config_file_name) + fused_moe, fused_marlin_moe, get_config_file_name) __all__ = [ "fused_moe", + "fused_marlin_moe", "get_config_file_name", ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3cb0419404625..8ca8bf22d1492 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2,11 +2,12 @@ import functools import json import os -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, List import torch import triton import triton.language as tl +import numpy from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -321,7 +322,7 @@ def fused_moe( w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -477,3 +478,332 @@ def fused_moe( out=hidden_states) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) + +def get_marlin_perms(): + perm = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): 
+ perm.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm) + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + perm = perm.reshape((-1, 8))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm, scale_perm, scale_perm_single + + +def marlin_permute_weights(q_w, size_k, size_n, num_bits, perm): + tile = 16 + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + # print("NUMEL", perm.numel()) + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, num_bits, perm) + + # Pack + pack_factor = 32 // num_bits + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + + print("PACKED:", q_w.shape, ">>", q_packed.shape) + + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm, scale_perm_single): + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + +def quantize_weights(w: torch.Tensor, 
num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Reshape to [groupsize, -1] + if group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s + + # Restore original shapes + if group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + s = s.reshape((-1, size_n)).contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s.to(device=orig_device), + ) + +# TODO rewrite these transformations for multi expert +def marlin_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + print("START:", w.size(), num_bits, group_size) + perm, scale_perm, scale_perm_single = get_marlin_perms() + + print("SHAPE:", w.shape) + #TODO experts dim + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize + w_ref, q_w, s = quantize_weights(w, num_bits, group_size) + print("interm:", w_ref.size(), q_w.size(), s.size()) + + #TODO experts + # Reformat to marlin + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, scale_perm, scale_perm_single) + + marlin_q_w = marlin_q_w + + # Create result + res_list = [w_ref, marlin_q_w, 
marlin_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + +def stack_and_dev(tensors: List[torch.Tensor]): + dev = tensors[0].device + return torch.stack(tensors, dim=0).to(dev) + + +def fused_marlin_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. 
+ assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + M, K = hidden_states.shape + E, N, _ = w1.shape + + if is_hip(): + # The MoE kernels are not yet supported on ROCm. + routing_weights = torch.softmax(gating_output, + dim=-1, + dtype=torch.float32) + topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) + else: + import vllm._moe_C as moe_kernels + + topk_weights = torch.empty(M, + topk, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + moe_kernels.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + ) + del token_expert_indicies # Not used. Will be used in the future. 
+ if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + if override_config: + config = override_config + else: + # First try to load optimal config from the file + configs = get_moe_configs(E, w2.shape[2], + "float8" if use_fp8 else None) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + } + + if M <= E: + config = { + 'BLOCK_SIZE_M': 16, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 1 + } + + intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config['BLOCK_SIZE_M'], E) + + # w_refs1 = [] + qweights1 = [] + scaless1 = [] + + for i in range(w1.shape[0]): + _, qweight, scales = marlin_quantize(w1[i], 4, -1) + # w_refs1.append(w_ref) + qweights1.append(qweight) + scaless1.append(scales) + + # w_ref1 = stack_and_dev(w_refs1) + qweight1 = stack_and_dev(qweights1) + scales1 = stack_and_dev(scaless1) + + # w_refs2 = [] + qweights2 = [] + scaless2 = [] + + for i in range(w2.shape[0]): + _, qweight, scales = marlin_quantize(w2[i], 4, -1) + # w_refs2.append(w_ref) + qweights2.append(qweight) + scaless2.append(scales) + + # w_ref2 = stack_and_dev(w_refs2) + qweight2 = stack_and_dev(qweights2) + scales2 = stack_and_dev(scaless2) + + max_workspace_size = (N // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device=hidden_states.device, + requires_grad=False) + + intermediate_cache1 = moe_kernels.marlin_gemm_moe(hidden_states, qweight1, sorted_token_ids, scales1, workspace, M, N, K) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + intermediate_cache3 = 
moe_kernels.marlin_gemm_moe(intermediate_cache2, qweight2, sorted_token_ids, scales2, workspace, M, N, K) + + if inplace: + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1)