diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py
index 07eada1af18b8..0661525e261e1 100755
--- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py
+++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py
@@ -5,31 +5,48 @@
 import torch
 import torch.nn.functional as F
 import triton
-import pandas as pd
 from tqdm import tqdm
+
 import vllm._moe_C as moe_kernels
 from vllm._C import ops
-
 from vllm.model_executor.layers.fused_moe import (fused_moe,
                                                   get_config_file_name,
-                                                  moe_align_block_size,
-                                                  invoke_fused_moe_kernel)
+                                                  invoke_fused_moe_kernel,
+                                                  moe_align_block_size)
 
-os.environ['CUDA_VISIBLE_DEVICES'] = '2'
-os.environ['HIP_FORCE_DEV_KERNARG'] = '1'
-os.environ['DEBUG_CLR_GRAPH_PACKET_CAPTURE'] = '1'
-os.environ['OPTIMIZE_EPILOGUE'] = '1'
+os.environ["CUDA_VISIBLE_DEVICES"] = "2"
+os.environ["HIP_FORCE_DEV_KERNARG"] = "1"
+os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1"
+os.environ["OPTIMIZE_EPILOGUE"] = "1"
 
 TP = 8
 
+
 def main():
     method = fused_moe
     for bs in [
-            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
-            2048, 3072, 4096
+            1,
+            2,
+            4,
+            8,
+            16,
+            24,
+            32,
+            48,
+            64,
+            96,
+            128,
+            256,
+            512,
+            1024,
+            1536,
+            2048,
+            3072,
+            4096,
     ]:
         run_grid(bs, method=method)
 
+
 ## Utilize method from rocm/Triton tuning script
 def get_full_tuning_space():
     configs = []
@@ -55,26 +72,36 @@ def get_full_tuning_space():
                         # for split_k in split_k_range:
                         for num_stages in num_stage_range:
                             for waves_per_eu in waves_per_eu_range:
-                                for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
+                                for (matrix_instr_nonkdim
+                                     ) in matrix_instr_nonkdim_range:
                                     for kpack in kpack_range:
-                                        configs.append({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu, 'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack})
+                                        configs.append({
+                                            "BLOCK_SIZE_M": block_m,
+                                            "BLOCK_SIZE_N": block_n,
+                                            "BLOCK_SIZE_K": block_k,
+                                            "GROUP_SIZE_M": group_m,
+                                            "num_warps": num_warps,
+                                            "num_stages": num_stages,
+                                            "waves_per_eu": waves_per_eu,
+                                            "matrix_instr_nonkdim":
+                                            matrix_instr_nonkdim,
+                                            "kpack": kpack,
+                                        })
 
     return configs
 
+
 ## Utilize method from rocm/Triton tuning script
 def prune_configs(M, N, K, configs):
     pruned_configs = []
-    elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes)
-    elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes)
+    elemBytes_a = 2  # [DV Note] Hard-coded for float16 (2 bytes)
+    elemBytes_b = 2  # [DV Note] Hard-coded for float16 (2 bytes)
 
-    if M < 32 or N < 32:
-        mfma = 16
-    else:
-        mfma = 32
+    mfma = 16 if M < 32 or N < 32 else 32
 
     # TODO (zhanglx): figure out the boundary between large and small gemms
     large_gemm = False
-    if M >= 2048 and N >=2048:
+    if M >= 2048 and N >= 2048:
         large_gemm = True
 
     for config in configs:
@@ -83,28 +110,29 @@ def prune_configs(M, N, K, configs):
         BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
         num_warps = config.get("num_warps")
         matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
-        kpack = config.get("kpack")
+        # kpack = config.get("kpack")
         if matrix_instr_nonkdim > mfma:
             continue
         if mfma == 4 and BLOCK_SIZE_K < 64:
             continue
         # some layouts could not work properly in case
-        # number elemens per thread is less 1
+        # number elements per thread is less 1
        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
             continue
-        SPLIT_K = 1 #config.get("SPLIT_K")
+        SPLIT_K = 1  # config.get("SPLIT_K")
         GROUP_M = config.get("GROUP_SIZE_M")
-        if BLOCK_SIZE_M < matrix_instr_nonkdim or BLOCK_SIZE_N < matrix_instr_nonkdim:
+        if (matrix_instr_nonkdim > BLOCK_SIZE_M
+                or matrix_instr_nonkdim > BLOCK_SIZE_N):
             continue
-        if M <= matrix_instr_nonkdim and BLOCK_SIZE_M != matrix_instr_nonkdim:
+        if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
             continue
-        if N <= matrix_instr_nonkdim and BLOCK_SIZE_N != matrix_instr_nonkdim:
+        if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
             continue
         # Skip BLOCK_SIZE that is too large compare to M/N
         # unless BLOCK_SIZE is already small enough
-        if BLOCK_SIZE_M > M * 2 and BLOCK_SIZE_M != 16:
+        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
             continue
-        if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16:
+        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
             continue
         # skip large split_k when not necessary
         if SPLIT_K != 1 and not need_split_k(M, N, K):
             continue
@@ -119,7 +147,8 @@ def prune_configs(M, N, K, configs):
             continue
         # out of shared memory resource
         # TODO (zhanglx): This does not consider the LDS usage in the epilogue
-        LDS = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
+        LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
+               BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
         if LDS > 65536:
             continue
         # Skip small block sizes and num_warps for large gemm
@@ -136,6 +165,7 @@ def prune_configs(M, N, K, configs):
 
     return pruned_configs
 
+
 def union_of_list_of_dicts(l1, l2):
     result = []
@@ -145,16 +175,18 @@ def union_of_list_of_dicts(l1, l2):
 
     return result
 
+
 def need_split_k(SIZE_M, SIZE_N, SIZE_K):
     return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
 
+
 def run_grid(bs, method):
     d_model = 4096
     num_total_experts = 8
     top_k = 2
     tp_size = TP
     model_intermediate_size = 14336
-    num_layers = 32
+    # num_layers = 32
     num_calls = 100
 
     num_warmup_trials = 1
@@ -172,7 +204,8 @@ def run_grid(bs, method):
     prune_configs_2 = prune_configs(M2, N2, K2, full_configs)
     configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)
 
-    print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | {len(prune_configs_2)=} | {len(configs)=}")
+    print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \
+          {len(prune_configs_2)=} | {len(configs)=}")
 
     best_config = None
     best_time_us = 1e20
@@ -210,7 +243,7 @@ def run_grid(bs, method):
             )
 
             kernel_dur_us = 1000 * kernel_dur_ms
-            model_dur_ms = kernel_dur_ms * num_layers
+            # model_dur_ms = kernel_dur_ms * num_layers
 
             if kernel_dur_us < best_time_us:
                 best_config = config
@@ -237,9 +270,17 @@ def run_grid(bs, method):
             f.write("\n")
 
 
-def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
-               top_k: int, tp_size: int, model_intermediate_size: int, method,
-               config) -> float:
+def run_timing(
+    num_calls: int,
+    bs: int,
+    d_model: int,
+    num_total_experts: int,
+    top_k: int,
+    tp_size: int,
+    model_intermediate_size: int,
+    method,
+    config,
+) -> float:
     shard_intermediate_size = model_intermediate_size // tp_size
 
     hidden_states = torch.rand(
@@ -260,13 +301,15 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
         dtype=hidden_states.dtype,
     )
 
-    gating_output = F.softmax(torch.rand(
-        # (num_calls, bs, num_total_experts), # THIS
-        (bs, num_total_experts),
-        device=hidden_states.device,
-        dtype=torch.float32,
-    ),
-                              dim=-1)
+    gating_output = F.softmax(
+        torch.rand(
+            # (num_calls, bs, num_total_experts), # THIS
+            (bs, num_total_experts),
+            device=hidden_states.device,
+            dtype=torch.float32,
+        ),
+        dim=-1,
+    )
 
     ###### Stuff from fused moe ######