From f1a56669f59cc1bf138523c1247dde9db7f28e56 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 17 Jan 2025 16:27:58 +0000 Subject: [PATCH] Start working on fused moe cutlass implementation Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 196 +++++++++++++++++- csrc/cpu/torch_bindings.cpp | 8 + .../epilogue/broadcast_load_epilogue_c3x.hpp | 6 +- csrc/ops.h | 6 + .../cutlass_w8a8/grouped_mm_c3x.cu | 47 +++++ .../cutlass_w8a8/scaled_mm_entry.cu | 16 ++ csrc/torch_bindings.cpp | 7 + tests/kernels/test_cutlass_moe.py | 145 +++++++++++++ 8 files changed, 422 insertions(+), 9 deletions(-) create mode 100644 tests/kernels/test_cutlass_moe.py diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index be401cec03c66..67923262a5855 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -6,6 +6,8 @@ from vllm import _custom_ops as ops from vllm.utils import FlexibleArgumentParser +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_moe, fused_topk, fused_experts) DEFAULT_MODELS = ["nm-testing/Mixtral-8x7B-Instruct-v0.1"] # "nm-testing/deepseekv2-lite", @@ -17,6 +19,11 @@ PER_ACT_TOKEN_OPTS = [False, True] PER_OUT_CH_OPTS = [False, True] +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + def grouped_gemm(a_g_tensors: List[torch.Tensor], b_g_tensors: List[torch.Tensor], out_g_tensors: List[torch.Tensor], @@ -33,6 +40,30 @@ def baseline_gemm(num_groups: int, a_tensors: List[torch.Tensor], b = b_tensors[g] out = torch.mm(a, b) out_tensors[g] = out + +def cutlass_fused(a_tensors: List[torch.Tensor], + w1_tensors: List[torch.Tensor], + w2_tensors: List[torch.Tensor], + c1_tensors: List[torch.Tensor], + c2_tensors: List[torch.Tensor], + c2_tensors_fp8: List[torch.Tensor], + c3_tensors: List[torch.Tensor], + a_scales: List[torch.Tensor], + w1_scales: List[torch.Tensor], + w2_scales: List[torch.Tensor], + c2_scales: List[torch.Tensor], + num_groups: int): + # output_dtype = c3_tensors[0].dtype + N = c2_tensors[0].shape[1] + ops.cutlass_grouped_mm(c1_tensors, a_tensors, w1_tensors, + a_scales, w1_scales) + # TODO make this work as it should + for idx in range(num_groups): + torch.ops._C.silu_and_mul(c2_tensors[idx], c1_tensors[idx].view(-1, N)) + print(c2_tensors[idx]) + c2_tensors_fp8[idx] = to_fp8(c2_tensors[idx].half()) + ops.cutlass_grouped_mm(c3_tensors, c2_tensors, w2_tensors, + c2_scales, w2_scales) def bench_run(results: List[benchmark.Measurement], model: str, num_groups: int, per_act_token: bool, per_out_ch: bool, @@ -47,11 +78,6 @@ def bench_run(results: List[benchmark.Measurement], model: str, num_groups: int, device = "cuda" out_dtype = torch.half - - def to_fp8(tensor: torch.Tensor): - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) a_tensors = [] b_tensors = [] @@ -142,6 +168,164 @@ def to_fp8(tensor: torch.Tensor): description="baseline_gemm", ).blocked_autorange(min_run_time=min_run_time)) +def bench_run_moe(results: List[benchmark.Measurement], model: str, num_groups: int, + per_act_token: bool, per_out_ch: bool, + mkn: List[Tuple[int, int, int]]): + label = "Quant Matmul" + + sub_label = ("{}, num_groups={}, per_act_token={} per_out_ch={}, " + "MKN=({})".format(model, num_groups, per_act_token, + per_out_ch, mkn)) + + print(f"Testing: {sub_label}") + + device = "cuda" + out_dtype = torch.bfloat16 + + def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + m_tot = sum([elem[0] for elem in mkn]) + k_g = mkn[0][1] + n_g = mkn[0][2] + + a_tensors = [] + w1_tensors = [] + w2_tensors = [] + c1_tensors = [] + c2_tensors = [] + c2_tensors_fp8 = [] + c3_tensors = [] + a_scales = [] + w1_scales = [] + w2_scales = [] + c2_scales = [] + + a = torch.randn((m_tot, k_g), device=device, dtype=out_dtype) + w1 = torch.randn((num_groups, 2 * n_g, k_g), device=device, dtype=out_dtype) + w2 = torch.randn((num_groups, k_g, n_g), device=device, dtype=out_dtype) + scored_output = torch.randn((m_tot, num_groups), device="cuda", dtype=out_dtype) + topk = 2 + # triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) + + #TODO grouped topk for deepseek + topk_weights, topk_ids = fused_topk(a, scored_output, topk, renormalize=True) + fused_experts(a, w1, w2, topk_weights, topk_ids) + topk_ids_cpu = topk_ids.cpu() + + occurrences = [0] * num_groups + expert_offsets = [0] * (num_groups + 1) + for id in topk_ids_cpu.flatten(): + occurrences[id] += 1 + + for e in range(num_groups): + expert_offsets[e + 1] = expert_offsets[e] + occurrences[e] + + print(expert_offsets, m_tot) + + a = torch.randn((m_tot, k_g)) + a_group[0] = a[sorted_token_ids[0]] + + # TODO + # create full input tensor m_tot x k_g x topk + # get shuffle data like sorted_token_ids etc. + # create view + + for g in range(num_groups): + m_g = occurrences[g] + a_g = to_fp8(torch.randn((m_g, k_g), device=device)) + w1_g = to_fp8(torch.randn((2 * n_g, k_g), device=device).t()) + w2_g = to_fp8(torch.randn((k_g, n_g), device=device).t()) + c1_g = torch.zeros((m_g, 2 * n_g), device=device, dtype=torch.bfloat16) + c2_g = torch.zeros((m_g, n_g), device=device, dtype=torch.bfloat16) + c2_g_fp8 = to_fp8(torch.zeros((m_g, n_g), device=device)) + c3_g = torch.zeros((m_g, k_g), device=device, dtype=torch.bfloat16) + # m_a_scales = m_g if per_act_token else 1 + # n_b_scales = n_g if per_out_ch else 1 + m_scales = 1 + n2_scales = 1 + k_scales = 1 + scale_a = (torch.randn((m_scales, 1), device=device, + dtype=torch.float32)) + scale_w1 = (torch.randn((n2_scales, 1), device=device, + dtype=torch.float32)) + scale_w2 = (torch.randn((k_scales, 1), device=device, + dtype=torch.float32)) + scale_c2 = (torch.randn((m_scales, 1), device=device, + dtype=torch.float32)) + + a_tensors.append(a_g) + w1_tensors.append(w1_g) + w2_tensors.append(w2_g) + c1_tensors.append(c1_g) + c2_tensors.append(c2_g) + c2_tensors_fp8.append(c2_g_fp8) + c3_tensors.append(c3_g) + a_scales.append(scale_a) + w1_scales.append(scale_w1) + w2_scales.append(scale_w2) + c2_scales.append(scale_c2) + + globals = { + # Gen params + "num_groups": num_groups, + # Grouped gemm params + "a_tensors": a_tensors, + "w1_tensors": w1_tensors, + "w2_tensors": w2_tensors, + "c1_tensors": c1_tensors, + "c2_tensors": c2_tensors, + "c2_tensors_fp8": c2_tensors_fp8, + "c3_tensors": c3_tensors, + "a_scales": a_scales, + "w1_scales": w1_scales, + "w2_scales": w2_scales, + "c2_scales": c2_scales, + # Triton params (fused_moe) + "a": a, + "w1": w1, + "w2": w2, + "scored_output": scored_output, + "topk": topk, + # Kernels + "fused_moe": fused_moe, + "cutlass_fused": cutlass_fused, + } + + min_run_time = 1 + num_warmup = 5 + + # Warmup triton + for _ in range(num_warmup): + fused_moe(a, w1, w2, scored_output, topk, renormalize=False) + + results.append( + benchmark.Timer( + stmt="fused_moe(a, w1, w2, scored_output, topk, renormalize=False)", + globals=globals, + label=label, + sub_label=sub_label, + description="grouped_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup cutlass + for _ in range(num_warmup): + cutlass_fused(a_tensors, w1_tensors, w2_tensors, c1_tensors, c2_tensors, + c2_tensors_fp8, c3_tensors, a_scales, w1_scales, + w2_scales, c2_scales, num_groups) + + results.append( + benchmark.Timer( + stmt= + "cutlass_fused(a_tensors, w1_tensors, w2_tensors, c1_tensors, c2_tensors, c2_tensors_fp8, c3_tensors, a_scales, w1_scales, w2_scales, c2_scales, num_groups)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="baseline_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + def main(args): print("Benchmarking models:") for i, model in enumerate(args.models): @@ -165,7 +349,7 @@ def main(args): for per_out_ch in PER_OUT_CH_OPTS: for size_m in DEFAULT_BATCH_SIZES: mkn = [(size_m, size_k, size_n)] * num_groups - bench_run(results, model, num_groups, per_act_token, + bench_run_moe(results, model, num_groups, per_act_token, per_out_ch, mkn) compare = benchmark.Compare(results) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index a720348dee3e2..96ddcab7cea26 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -125,6 +125,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor out_offsets, Tensor a_offsets, " " Tensor b_offsets) -> ()"); ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); + + ops.def( + "compute_expert_offsets(Tensor! trg_a_ptrs," + " Tensor! a, Tensor topk_ids," + " Tensor! expert_offsets, SymInt num_experts) -> ()"); + ops.impl("compute_expert_offsets", torch::kCUDA, + &compute_expert_offsets); + // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column // quantization. ops.def( diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp index 9f049efd07b46..ad33eec9ef8fe 100644 --- a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp @@ -424,9 +424,9 @@ struct Sm90ColOrScalarBroadcast { auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; - if (threadIdx.x ==128){ - printf("M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); - } + // if (threadIdx.x ==128){ + // printf("M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); + // } Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); diff --git a/csrc/ops.h b/csrc/ops.h index 736a40091f032..d7ec0e0f91283 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -161,6 +161,12 @@ void cutlass_grouped_mm(c10::List const& out_tensors, c10::List const& a_scales, c10::List const& b_scales); +void compute_expert_offsets(torch::Tensor& trg_a_ptrs, + torch::Tensor& a, + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + const int64_t num_experts); + void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index caa7edb888a36..835a144aed4b1 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -321,3 +321,50 @@ void cutlass_grouped_mm_sm90(c10::List const& out_tensors, cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales); } + +__global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs, + cutlass::float_e4m3_t* base_a_ptr, + const int* __restrict__ topk_ids, + int64_t* expert_offsets, + int topk_length) { + int expert_id = threadIdx.x; + int num_experts = blockDim.x; + + int occurrences = 0; + for (int i = 0; i < topk_length; ++i) { + occurrences += (topk_ids[i] == expert_id); + } + expert_offsets[expert_id + 1] = occurrences; + __syncthreads(); + + if (threadIdx.x == 0) { + int64_t tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + trg_a_ptrs[i] = base_a_ptr + tot_offset; + tot_offset += expert_offsets[i + 1]; + expert_offsets[i + 1] = tot_offset; + } + } +} + +// For a given "a" of size [M,K] performs a permutation of the M rows based +// on the given "perm" indices. +__global__ void permute_rows_kernel(cutlass::float_e4m3_t const* __restrict__ a_ptr, + int const* __restrict__ perm_int_ptr, + cutlass::float_e4m3_t* __restrict__ out_ptr, + int size_m, int size_k, int block_rows) { + // TODO +} + +void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, + torch::Tensor& a, + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + const int64_t num_experts) { + get_a_expert_offsets<<<1, num_experts>>>((float_e4m3_t**)trg_a_ptrs.data_ptr(), + (cutlass::float_e4m3_t*)a.data_ptr(), + (const int*)topk_ids.data_ptr(), + (int64_t*)expert_offsets.data_ptr(), + topk_ids.numel()); +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index e60b64d7797bf..d9d2a91d0659f 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -36,6 +36,13 @@ void cutlass_grouped_mm_sm90(c10::List const& out_tensors, c10::List const& a_scales, c10::List const& b_scales); + +void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, + torch::Tensor& a, + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + const int64_t num_experts); + #endif void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, @@ -159,6 +166,15 @@ void cutlass_grouped_mm(c10::List const& out_tensors, b_scales); } +void compute_expert_offsets(torch::Tensor& trg_a_ptrs, + torch::Tensor& a, + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + const int64_t num_experts) { + compute_expert_offsets_caller(trg_a_ptrs, a, topk_ids, expert_offsets, + num_experts); +} + void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b862144aa16f5..65d48c7f14659 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -329,6 +329,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor[] b_scales) -> ()"); ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); + ops.def( + "compute_expert_offsets(Tensor! trg_a_ptrs," + " Tensor! a, Tensor topk_ids," + " Tensor! expert_offsets, SymInt num_experts) -> ()"); + ops.impl("compute_expert_offsets", torch::kCUDA, + &compute_expert_offsets); + // Check if cutlass sparse scaled_mm is supported for CUDA devices of the // given capability ops.def( diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py new file mode 100644 index 0000000000000..083de75e1d34f --- /dev/null +++ b/tests/kernels/test_cutlass_moe.py @@ -0,0 +1,145 @@ +import pytest +import torch +from transformers import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +from typing import List + +import vllm.model_executor.layers.fused_moe # noqa +from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, + torch_moe, torch_moe_single) +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( + fused_moe as iterative_moe) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + marlin_quantize) +from vllm.model_executor.models.mixtral import MixtralMoE +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +NUM_EXPERTS = [8, 64] +TOP_KS = [2, 6] + +# TODO move to a better file later +# TODO handle scores +def cutlass_moe(a: torch.Tensor, + a_q: torch.Tensor, + a_scale: torch.Tensor, + w1_qs: List[torch.Tensor], + w2_qs: List[torch.Tensor], + w1_scales: List[torch.Tensor], + w2_scales: List[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, +): + # TODO look at the code in benchmark_grouped_gemm_cutlass.py + # and get the relevant parts + # (also the fused_moe function) + + num_groups = len(w1_qs) + topk = topk_ids.shape[1] + num_tokens = topk_ids.shape[0] + + # TODO make this GPU only + occurrences = [0] * num_groups + expert_offsets = [0] * (num_groups + 1) + for id in topk_ids.cpu().flatten(): + occurrences[id] += 1 + for e in range(num_groups): + expert_offsets[e + 1] = expert_offsets[e] + occurrences[e] + + # TODO duplicate A rows topk times + # compute sorted_token_ids (argsort?) + # shuffle A according to this so each group input is contiguous + + # print(topk_ids) + # print(expert_offsets) + a_map = topk_ids.flatten().argsort() + rep_a_q = a_q.repeat_interleave(topk, dim=0) + + print(a_map) + print(rep_a_q) + + a_q_s = [] + for e in range(num_groups): + a_q_s.append(rep_a_q[a_map[expert_offsets[e]:expert_offsets[e+1]]]) + print(a_q_s) + return + # get a_map and expert_indices on device + + # TODO shuffle rep_a_q according to a_map + # get a_ptrs = a + expert_indices[:-1] + + a_ptrs = torch.empty((num_groups), dtype=torch.int64, device="cuda") + expert_offsets = torch.empty((num_groups + 1), dtype=torch.int64, device="cuda") + # TODO might need to call it from inside cutlass code? + # help(ops) + + # print(a_ptrs) + # print(rep_a_q) + print(topk_ids) + # print(expert_offsets) + # print(num_groups) + torch.ops._C.compute_expert_offsets(a_ptrs, rep_a_q, topk_ids.cuda(), + expert_offsets, num_groups) + print(a_ptrs) + print(expert_offsets) + +# @pytest.mark.parametrize("m", [1, 33, 64, 222]) +# @pytest.mark.parametrize("n", [128, 2048]) +# @pytest.mark.parametrize("k", [128, 1024]) +# @pytest.mark.parametrize("e", NUM_EXPERTS) +# @pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("m", [10]) +@pytest.mark.parametrize("n", [128]) +@pytest.mark.parametrize("k", [128]) +@pytest.mark.parametrize("e", [8]) +@pytest.mark.parametrize("topk", [2]) +def test_cutlass_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, +): + current_platform.seed_everything(7) + + dtype = torch.bfloat16 + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + a_q, a_scale = ops.scaled_fp8_quant(a) + + w1_qs = [] + w2_qs = [] + w1_scales = [] + w2_scales = [] + + for expert in range(e): + w1_q, w1_scale = ops.scaled_fp8_quant(w1[expert]) + w2_q, w2_scale = ops.scaled_fp8_quant(w2[expert]) + w1_qs.append(w1_q) + w2_qs.append(w2_q) + w1_scales.append(w1_scale) + w2_scales.append(w2_scale) + + # (assume score is a vector of ones for now) + score = torch.ones((m, e), device="cuda", dtype=dtype) + + e_range = torch.full((m, e), 1.0 / e) + topk_ids = torch.multinomial(e_range, topk).int().sort()[0] + topk_weights = torch.rand((m, topk)) + + torch_output = torch_moe(a, w1, w2, score, topk) + cutlass_output = cutlass_moe(a, a_q, a_scale, w1_qs, w2_qs, w1_scales, + w2_scales, topk_weights, topk_ids) + + # torch.testing.assert_close(torch_output, + # cutlass_output, + # atol=2e-2, + # rtol=0)