Skip to content

Commit

Permalink
ruff & yapf
Browse files: browse the repository at this point in the history
  • Loading branch information
divakar-amd committed Aug 16, 2024
1 parent 49473c2 commit cd68e07
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
43 changes: 27 additions & 16 deletions benchmarks/kernels/benchmark_mixtral_moe_rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,50 @@
import sys

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
import triton
import triton.language as tl
from tqdm import tqdm
import torch.distributed as dist
import torch.multiprocessing as mp

from tuning_utils import (get_full_tuning_space, prune_configs, union_of_list_of_dicts)
from tuning_utils import (
get_full_tuning_space,
prune_configs,
union_of_list_of_dicts,
)

import vllm._moe_C as moe_kernels
from vllm._C import ops
from vllm.model_executor.layers.fused_moe import (get_config_file_name,
invoke_fused_moe_kernel,
moe_align_block_size)
from vllm.model_executor.layers.fused_moe import (
get_config_file_name,
invoke_fused_moe_kernel,
moe_align_block_size,
)

os.environ["HIP_FORCE_DEV_KERNARG"] = "1"
os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1"


def main(args):
world_size = args.numGPU
mp.spawn(wrapper, args=(args,), nprocs=world_size, join=False)
mp.spawn(wrapper, args=(args, ), nprocs=world_size, join=False)



def wrapper(rank, args):
dist.init_process_group("nccl", world_size=args.numGPU, rank=rank)
device_id = rank

batches = [1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096]

batches = [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048,
3072, 4096
]
for i in range(device_id, len(batches), args.numGPU):
tune_batch(batches[i], model=args.model, TP=args.modelTP)


def tune_batch(bs, model, TP):
def tune_batch(bs, model, TP):
device_id = torch.distributed.get_rank()

if model == '8x7B':
d_model = 4096
model_intermediate_size = 14336
Expand Down Expand Up @@ -69,7 +78,9 @@ def tune_batch(bs, model, TP):
best_config = None
best_time_us = 1e20

progress_bar = tqdm(total=len(configs), desc=f"bs={bs:4d} device={device_id}", position=device_id)
progress_bar = tqdm(total=len(configs),
desc=f"bs={bs:4d} device={device_id}",
position=device_id)

with torch.cuda.device(device_id):
for config in configs:
Expand Down Expand Up @@ -135,7 +146,7 @@ def run_timing(

device_ = "cuda"
dtype_ = torch.float16

hidden_states = torch.rand(
(bs, d_model),
device=device_,
Expand Down Expand Up @@ -194,7 +205,7 @@ def run_timing(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
gating_output.float(),
)
del token_expert_indicies # Not used. Will be used in the future.

Expand Down
2 changes: 0 additions & 2 deletions benchmarks/kernels/tuning_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


## Utilize method from rocm/Triton tuning script
def get_full_tuning_space():
configs = []
Expand Down

0 comments on commit cd68e07

Please sign in to comment.