
make ruff & yapf & isort happy
divakar-amd committed Jun 11, 2024
1 parent 91acb71 commit eb80843
Showing 1 changed file with 84 additions and 41 deletions.
benchmarks/kernels/benchmark_mixtral_moe_rocm.py
@@ -5,31 +5,48 @@
import torch
import torch.nn.functional as F
import triton
import pandas as pd
from tqdm import tqdm

import vllm._moe_C as moe_kernels
from vllm._C import ops

from vllm.model_executor.layers.fused_moe import (fused_moe,
get_config_file_name,
-moe_align_block_size,
-invoke_fused_moe_kernel)
+invoke_fused_moe_kernel,
+moe_align_block_size)

-os.environ['CUDA_VISIBLE_DEVICES'] = '2'
-os.environ['HIP_FORCE_DEV_KERNARG'] = '1'
-os.environ['DEBUG_CLR_GRAPH_PACKET_CAPTURE'] = '1'
-os.environ['OPTIMIZE_EPILOGUE'] = '1'
+os.environ["CUDA_VISIBLE_DEVICES"] = "2"
+os.environ["HIP_FORCE_DEV_KERNARG"] = "1"
+os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1"
+os.environ["OPTIMIZE_EPILOGUE"] = "1"

TP = 8


def main():
method = fused_moe
for bs in [
-1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
-2048, 3072, 4096
+1,
+2,
+4,
+8,
+16,
+24,
+32,
+48,
+64,
+96,
+128,
+256,
+512,
+1024,
+1536,
+2048,
+3072,
+4096,
]:
run_grid(bs, method=method)


## Utilize method from rocm/Triton tuning script
def get_full_tuning_space():
configs = []
@@ -55,26 +72,36 @@ def get_full_tuning_space():
# for split_k in split_k_range:
for num_stages in num_stage_range:
for waves_per_eu in waves_per_eu_range:
-for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
+for (matrix_instr_nonkdim
+) in matrix_instr_nonkdim_range:
for kpack in kpack_range:
-configs.append({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu, 'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack})
+configs.append({
+"BLOCK_SIZE_M": block_m,
+"BLOCK_SIZE_N": block_n,
+"BLOCK_SIZE_K": block_k,
+"GROUP_SIZE_M": group_m,
+"num_warps": num_warps,
+"num_stages": num_stages,
+"waves_per_eu": waves_per_eu,
+"matrix_instr_nonkdim":
+matrix_instr_nonkdim,
+"kpack": kpack,
+})

return configs


## Utilize method from rocm/Triton tuning script
def prune_configs(M, N, K, configs):
pruned_configs = []
-elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes)
-elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes)
+elemBytes_a = 2  # [DV Note] Hard-coded for float16 (2 bytes)
+elemBytes_b = 2  # [DV Note] Hard-coded for float16 (2 bytes)

-if M < 32 or N < 32:
-mfma = 16
-else:
-mfma = 32
+mfma = 16 if M < 32 or N < 32 else 32

# TODO (zhanglx): figure out the boundary between large and small gemms
large_gemm = False
-if M >= 2048 and N >=2048:
+if M >= 2048 and N >= 2048:
large_gemm = True

for config in configs:
@@ -83,28 +110,29 @@ def prune_configs(M, N, K, configs):
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
num_warps = config.get("num_warps")
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
-kpack = config.get("kpack")
+# kpack = config.get("kpack")
if matrix_instr_nonkdim > mfma:
continue
if mfma == 4 and BLOCK_SIZE_K < 64:
continue
# some layouts could not work properly in case
-# number elemens per thread is less 1
+# number elements per thread is less 1
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
continue
-SPLIT_K = 1 #config.get("SPLIT_K")
+SPLIT_K = 1  # config.get("SPLIT_K")
GROUP_M = config.get("GROUP_SIZE_M")
-if BLOCK_SIZE_M < matrix_instr_nonkdim or BLOCK_SIZE_N < matrix_instr_nonkdim:
+if (matrix_instr_nonkdim > BLOCK_SIZE_M
+or matrix_instr_nonkdim > BLOCK_SIZE_N):
continue
-if M <= matrix_instr_nonkdim and BLOCK_SIZE_M != matrix_instr_nonkdim:
+if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
continue
-if N <= matrix_instr_nonkdim and BLOCK_SIZE_N != matrix_instr_nonkdim:
+if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
continue
# Skip BLOCK_SIZE that is too large compare to M/N
# unless BLOCK_SIZE is already small enough
-if BLOCK_SIZE_M > M * 2 and BLOCK_SIZE_M != 16:
+if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
continue
-if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16:
+if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
continue
# skip large split_k when not necessary
if SPLIT_K != 1 and not need_split_k(M, N, K):
@@ -119,7 +147,8 @@ def prune_configs(M, N, K, configs):
continue
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
-LDS = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
+LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
+BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
if LDS > 65536:
continue
# Skip small block sizes and num_warps for large gemm
@@ -136,6 +165,7 @@ def prune_configs(M, N, K, configs):

return pruned_configs


def union_of_list_of_dicts(l1, l2):
result = []
l1.extend(l2)
@@ -145,16 +175,18 @@ def union_of_list_of_dicts(l1, l2):

return result


def need_split_k(SIZE_M, SIZE_N, SIZE_K):
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024


def run_grid(bs, method):
d_model = 4096
num_total_experts = 8
top_k = 2
tp_size = TP
model_intermediate_size = 14336
-num_layers = 32
+# num_layers = 32
num_calls = 100

num_warmup_trials = 1
@@ -172,7 +204,8 @@ def run_grid(bs, method):
prune_configs_2 = prune_configs(M2, N2, K2, full_configs)

configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)
-print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | {len(prune_configs_2)=} | {len(configs)=}")
+print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \
+{len(prune_configs_2)=} | {len(configs)=}")

best_config = None
best_time_us = 1e20
@@ -210,7 +243,7 @@
)

kernel_dur_us = 1000 * kernel_dur_ms
-model_dur_ms = kernel_dur_ms * num_layers
+# model_dur_ms = kernel_dur_ms * num_layers

if kernel_dur_us < best_time_us:
best_config = config
@@ -237,9 +270,17 @@ def run_grid(bs, method):
f.write("\n")


-def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
-top_k: int, tp_size: int, model_intermediate_size: int, method,
-config) -> float:
+def run_timing(
+num_calls: int,
+bs: int,
+d_model: int,
+num_total_experts: int,
+top_k: int,
+tp_size: int,
+model_intermediate_size: int,
+method,
+config,
+) -> float:
shard_intermediate_size = model_intermediate_size // tp_size

hidden_states = torch.rand(
@@ -260,13 +301,15 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
dtype=hidden_states.dtype,
)

-gating_output = F.softmax(torch.rand(
-# (num_calls, bs, num_total_experts), # THIS
-(bs, num_total_experts),
-device=hidden_states.device,
-dtype=torch.float32,
-),
-dim=-1)
+gating_output = F.softmax(
+torch.rand(
+# (num_calls, bs, num_total_experts), # THIS
+(bs, num_total_experts),
+device=hidden_states.device,
+dtype=torch.float32,
+),
+dim=-1,
+)

###### Stuff from fused moe ######

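For readers tracing the tuning flow that run_grid exercises in the diff above (build the full config space, prune it for each GEMM shape, then union the two pruned lists before timing), here is a minimal self-contained sketch of that pattern. The helper names full_space and prune, the toy block-size ranges, and the example M/N/K values are illustrative assumptions, not the benchmark's actual implementation; only union_of_list_of_dicts mirrors a function name from the file, and the dedup strategy shown is just one possible approach.

import json


def full_space():
    # Tiny stand-in for get_full_tuning_space(): enumerate candidate tile configs.
    return [{"BLOCK_SIZE_M": m, "BLOCK_SIZE_N": n, "BLOCK_SIZE_K": k}
            for m in (16, 64, 128)
            for n in (16, 64, 128)
            for k in (32, 64)]


def prune(M, N, K, configs):
    # Keep configs whose tiles are not oversized for the (M, N) problem shape,
    # echoing the "skip BLOCK_SIZE too large compared to M/N" checks in the diff.
    # K is unused in this simplified filter.
    kept = []
    for c in configs:
        if M * 2 < c["BLOCK_SIZE_M"] and c["BLOCK_SIZE_M"] != 16:
            continue
        if N * 2 < c["BLOCK_SIZE_N"] and c["BLOCK_SIZE_N"] != 16:
            continue
        kept.append(c)
    return kept


def union_of_list_of_dicts(l1, l2):
    # One way to de-duplicate a list of dicts: compare serialized forms.
    seen, result = set(), []
    for d in l1 + l2:
        key = json.dumps(d, sort_keys=True)
        if key not in seen:
            seen.add(key)
            result.append(d)
    return result


if __name__ == "__main__":
    space = full_space()
    # Two example GEMM shapes; the real script derives M/N/K from bs, d_model,
    # model_intermediate_size and tp_size.
    cfgs_a = prune(M=64, N=1792, K=4096, configs=space)
    cfgs_b = prune(M=64, N=4096, K=1792, configs=space)
    merged = union_of_list_of_dicts(cfgs_a, cfgs_b)
    print(f"{len(space)=} | {len(cfgs_a)=} | {len(cfgs_b)=} | {len(merged)=}")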
