diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index 399348300554a..5f84c7bc6c680 100755 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -4,41 +4,50 @@ import sys import torch +import torch.distributed as dist +import torch.multiprocessing as mp import torch.nn.functional as F import triton import triton.language as tl from tqdm import tqdm -import torch.distributed as dist -import torch.multiprocessing as mp - -from tuning_utils import (get_full_tuning_space, prune_configs, union_of_list_of_dicts) +from tuning_utils import ( + get_full_tuning_space, + prune_configs, + union_of_list_of_dicts, +) import vllm._moe_C as moe_kernels from vllm._C import ops -from vllm.model_executor.layers.fused_moe import (get_config_file_name, - invoke_fused_moe_kernel, - moe_align_block_size) +from vllm.model_executor.layers.fused_moe import ( + get_config_file_name, + invoke_fused_moe_kernel, + moe_align_block_size, +) os.environ["HIP_FORCE_DEV_KERNARG"] = "1" os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1" + def main(args): world_size = args.numGPU - mp.spawn(wrapper, args=(args,), nprocs=world_size, join=False) + mp.spawn(wrapper, args=(args, ), nprocs=world_size, join=False) + - def wrapper(rank, args): dist.init_process_group("nccl", world_size=args.numGPU, rank=rank) device_id = rank - - batches = [1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096] + + batches = [ + 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, + 3072, 4096 + ] for i in range(device_id, len(batches), args.numGPU): tune_batch(batches[i], model=args.model, TP=args.modelTP) -def tune_batch(bs, model, TP): +def tune_batch(bs, model, TP): device_id = torch.distributed.get_rank() - + if model == '8x7B': d_model = 4096 model_intermediate_size = 14336 @@ -69,7 +78,9 @@ def tune_batch(bs, model, TP): best_config = None best_time_us = 1e20 - progress_bar = tqdm(total=len(configs), desc=f"bs={bs:4d} device={device_id}", position=device_id) + progress_bar = tqdm(total=len(configs), + desc=f"bs={bs:4d} device={device_id}", + position=device_id) with torch.cuda.device(device_id): for config in configs: @@ -135,7 +146,7 @@ def run_timing( device_ = "cuda" dtype_ = torch.float16 - + hidden_states = torch.rand( (bs, d_model), device=device_, @@ -194,7 +205,7 @@ def run_timing( topk_weights, topk_ids, token_expert_indicies, - gating_output.float(), + gating_output.float(), ) del token_expert_indicies # Not used. Will be used in the future. diff --git a/benchmarks/kernels/tuning_utils.py b/benchmarks/kernels/tuning_utils.py index 120e9de870867..ca75d59fe71ba 100644 --- a/benchmarks/kernels/tuning_utils.py +++ b/benchmarks/kernels/tuning_utils.py @@ -1,5 +1,3 @@ - - ## Utilize method from rocm/Triton tuning script def get_full_tuning_space(): configs = []