From 7716f476cb11fdfed1d0fb1d4234656a27a4a89a Mon Sep 17 00:00:00 2001 From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Date: Wed, 5 Jun 2024 09:21:42 -0500 Subject: [PATCH 1/4] tuned moe configs v2 Update the fused moe config.json files. These config files utilizes all the available Triton kernel parameters for tuning. --- ...14336,device_name=AMD_Instinct_MI300X.json | 164 +++++++++++----- ...=1792,device_name=AMD_Instinct_MI300X.json | 182 +++++++++++++----- ...=3584,device_name=AMD_Instinct_MI300X.json | 172 ++++++++++++----- ...=7168,device_name=AMD_Instinct_MI300X.json | 176 ++++++++++++----- 4 files changed, 500 insertions(+), 194 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json index 93472eb08a462..6a976788f9b10 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json @@ -1,128 +1,200 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_stages": 1 + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_stages": 1 + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 8, - "num_stages": 1 + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "24": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_stages": 1 + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 8, - "num_stages": 1 + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "48": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - "num_stages": 0 + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "64": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - "num_stages": 0 + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "96": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_stages": 0 + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "128": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - "num_stages": 0 + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "256": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - "num_stages": 0 + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "512": { - "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - "num_stages": 0 + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "1024": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 }, "1536": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "2048": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "3072": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "4096": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 } } diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json index 5bd9d71e8f9bb..0a46390b2e31b 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json @@ -1,110 +1,200 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64 + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 32 + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8 + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1 + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + 
"kpack": 2 }, "16": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1 + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "24": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1 - }, - "32": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8 + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "48": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8 + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "64": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1 + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "96": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8 + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "128": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32 + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "256": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1 + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1 + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "1024": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1 + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "1536": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1 + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "2048": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1 + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "3072": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1 + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "4096": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1 + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 } } diff --git 
a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json index 02e66280c1a3a..91011e64c7de4 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json @@ -1,128 +1,200 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - "num_stages": 0 + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 32, - "num_stages": 1 + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "8": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 8, - "num_stages": 1 + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "16": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_stages": 1 + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "24": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 8, - "num_stages": 1 + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 16, - "num_stages": 0 + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "48": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_stages": 1 + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "64": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 32, - "num_stages": 0 + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "96": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 16, - "num_stages": 0 + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - "num_stages": 0 + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "256": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - "num_stages": 0 + 
"GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "512": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 }, "1024": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "1536": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "2048": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "3072": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "4096": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 } } diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json index 34c3b593d9799..f807d4a5abaed 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json @@ -1,128 +1,200 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, - "num_stages": 1 + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_stages": 1 + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 8, - "num_stages": 1 + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 8, - "num_stages": 1 + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 8, - "num_stages": 1 + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "32": { "BLOCK_SIZE_M": 
16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_stages": 0 + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "48": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_stages": 0 + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "64": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 8, - "num_stages": 1 + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "96": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - "num_stages": 0 + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "128": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - "num_stages": 0 + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "256": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - "num_stages": 0 + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 }, "512": { - "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - "num_stages": 0 + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "1024": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "1536": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "2048": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 }, "3072": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 }, "4096": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_stages": 0 + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 } } From 91acb711b0905e62b8a6f10eb97e5e76ffaa68b5 Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Tue, 11 Jun 2024 21:31:43 +0000 Subject: [PATCH 2/4] add moe tuning script v2 --- .../kernels/benchmark_mixtral_moe_rocm.py | 372 ++++++++++++++++++ 1 file changed, 372 insertions(+) create mode 100755 benchmarks/kernels/benchmark_mixtral_moe_rocm.py diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py new file mode 100755 index 0000000000000..07eada1af18b8 --- /dev/null +++ 
b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -0,0 +1,372 @@ +import json +import os +import sys + +import torch +import torch.nn.functional as F +import triton +import pandas as pd +from tqdm import tqdm +import vllm._moe_C as moe_kernels +from vllm._C import ops + +from vllm.model_executor.layers.fused_moe import (fused_moe, + get_config_file_name, + moe_align_block_size, + invoke_fused_moe_kernel) + +os.environ['CUDA_VISIBLE_DEVICES'] = '2' +os.environ['HIP_FORCE_DEV_KERNARG'] = '1' +os.environ['DEBUG_CLR_GRAPH_PACKET_CAPTURE'] = '1' +os.environ['OPTIMIZE_EPILOGUE'] = '1' + +TP = 8 + +def main(): + method = fused_moe + for bs in [ + 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, + 2048, 3072, 4096 + ]: + run_grid(bs, method=method) + +## Utilize method from rocm/Triton tuning script +def get_full_tuning_space(): + configs = [] + + block_mn_range = [16, 32, 64, 128, 256] + block_k_range = [16, 32, 64, 128, 256] + # split_k_range = [1] #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8, 16, 32] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [0] + waves_per_eu_range = [0] + matrix_instr_nonkdim_range = [16, 32] + kpack_range = [1, 2] + + for block_m in block_mn_range: + for block_n in block_mn_range: + for block_k in block_k_range: + for num_warps in num_warps_range: + for group_m in group_m_range: + # for split_k in split_k_range: + for num_stages in num_stage_range: + for waves_per_eu in waves_per_eu_range: + for matrix_instr_nonkdim in matrix_instr_nonkdim_range: + for kpack in kpack_range: + configs.append({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu, 'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack}) + + return configs + +## Utilize method from rocm/Triton tuning script +def prune_configs(M, N, K, configs): + pruned_configs = [] + elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes) + elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes) + + if M < 32 or N < 32: + mfma = 16 + else: + mfma = 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >=2048: + large_gemm = True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + num_warps = config.get("num_warps") + matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") + kpack = config.get("kpack") + if matrix_instr_nonkdim > mfma: + continue + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + # some layouts could not work properly in case + # number elemens per thread is less 1 + if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: + continue + SPLIT_K = 1 #config.get("SPLIT_K") + GROUP_M = config.get("GROUP_SIZE_M") + if BLOCK_SIZE_M < matrix_instr_nonkdim or BLOCK_SIZE_N < matrix_instr_nonkdim: + continue + if M <= matrix_instr_nonkdim and BLOCK_SIZE_M != matrix_instr_nonkdim: + continue + if N <= matrix_instr_nonkdim and BLOCK_SIZE_N != matrix_instr_nonkdim: + continue + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if BLOCK_SIZE_M > M * 2 and BLOCK_SIZE_M != 16: + continue + if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16: + continue + # 
skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(M, N, K): + continue + # skip split_k that leads to EVEN_K = false + leap = SPLIT_K * BLOCK_SIZE_K + modv = K % leap + if modv != 0: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + + pruned_configs.append(config) + + return pruned_configs + +def union_of_list_of_dicts(l1, l2): + result = [] + l1.extend(l2) + for myDict in l1: + if myDict not in result: + result.append(myDict) + + return result + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + +def run_grid(bs, method): + d_model = 4096 + num_total_experts = 8 + top_k = 2 + tp_size = TP + model_intermediate_size = 14336 + num_layers = 32 + num_calls = 100 + + num_warmup_trials = 1 + num_trials = 1 + + full_configs = get_full_tuning_space() + M1 = bs * 2 + N1 = model_intermediate_size * 2 // tp_size + K1 = 4096 + prune_configs_1 = prune_configs(M1, N1, K1, full_configs) + + M2 = bs * 2 + N2 = 4096 + K2 = model_intermediate_size // tp_size + prune_configs_2 = prune_configs(M2, N2, K2, full_configs) + + configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2) + print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | {len(prune_configs_2)=} | {len(configs)=}") + + best_config = None + best_time_us = 1e20 + + for config in tqdm(configs): + # warmup + try: + for _ in range(num_warmup_trials): + run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + method=method, + config=config, + ) + except triton.runtime.autotuner.OutOfResources: + continue + + # benchmark + for _ in range(num_trials): + kernel_dur_ms = run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + method=method, + config=config, + ) + + kernel_dur_us = 1000 * kernel_dur_ms + model_dur_ms = kernel_dur_ms * num_layers + + if kernel_dur_us < best_time_us: + best_config = config + best_time_us = kernel_dur_us + + # print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' + # f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' + # f'{d_model=} {model_intermediate_size=} {num_layers=}') + + # print("best_time_us", best_time_us) + # print("best_config", best_config) + + # holds Dict[str, Dict[str, int]] + filename = get_config_file_name(num_total_experts, + model_intermediate_size // tp_size) + print(f"writing config to file {filename}") + existing_content = {} + if os.path.exists(filename): + with open(filename, "r") as f: + existing_content = json.load(f) + existing_content[str(bs)] = best_config + with open(filename, "w") as f: + json.dump(existing_content, f, indent=4) + f.write("\n") + + +def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, + top_k: int, tp_size: int, model_intermediate_size: int, method, + config) -> float: + 
shard_intermediate_size = model_intermediate_size // tp_size + + hidden_states = torch.rand( + (bs, d_model), + device="cuda:0", + dtype=torch.float16, + ) + + w1 = torch.rand( + (num_total_experts, 2 * shard_intermediate_size, d_model), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + w2 = torch.rand( + (num_total_experts, d_model, shard_intermediate_size), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + gating_output = F.softmax(torch.rand( + # (num_calls, bs, num_total_experts), # THIS + (bs, num_total_experts), + device=hidden_states.device, + dtype=torch.float32, + ), + dim=-1) + + ###### Stuff from fused moe ###### + + assert (hidden_states.shape[0] == gating_output.shape[0] + ), "Number of tokens mismatch" + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + M, _ = hidden_states.shape + E, N, _ = w1.shape + topk_ = 2 + topk_weights = torch.empty(M, + topk_, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty(M, + topk_, + dtype=torch.int32, + device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk_, + dtype=torch.int32, + device=hidden_states.device) + moe_kernels.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + ) + del token_expert_indicies # Not used. Will be used in the future. + + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config["BLOCK_SIZE_M"], E) + + ################################## + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for i in range(num_calls): + invoke_fused_moe_kernel( + hidden_states, + w1, + intermediate_cache1, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + ) + + end_event.record() + end_event.synchronize() + + dur_ms = start_event.elapsed_time(end_event) / num_calls + return dur_ms + + +if __name__ == "__main__": + sys.exit(main()) From eb808433510d172078fe957243e71c6bff182788 Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Tue, 11 Jun 2024 21:50:43 +0000 Subject: [PATCH 3/4] make ruff & yapf & isort happy --- .../kernels/benchmark_mixtral_moe_rocm.py | 125 ++++++++++++------ 1 file changed, 84 insertions(+), 41 deletions(-) diff --git 
a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index 07eada1af18b8..0661525e261e1 100755 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -5,31 +5,48 @@ import torch import torch.nn.functional as F import triton -import pandas as pd from tqdm import tqdm + import vllm._moe_C as moe_kernels from vllm._C import ops - from vllm.model_executor.layers.fused_moe import (fused_moe, get_config_file_name, - moe_align_block_size, - invoke_fused_moe_kernel) + invoke_fused_moe_kernel, + moe_align_block_size) -os.environ['CUDA_VISIBLE_DEVICES'] = '2' -os.environ['HIP_FORCE_DEV_KERNARG'] = '1' -os.environ['DEBUG_CLR_GRAPH_PACKET_CAPTURE'] = '1' -os.environ['OPTIMIZE_EPILOGUE'] = '1' +os.environ["CUDA_VISIBLE_DEVICES"] = "2" +os.environ["HIP_FORCE_DEV_KERNARG"] = "1" +os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1" +os.environ["OPTIMIZE_EPILOGUE"] = "1" TP = 8 + def main(): method = fused_moe for bs in [ - 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, - 2048, 3072, 4096 + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, ]: run_grid(bs, method=method) + ## Utilize method from rocm/Triton tuning script def get_full_tuning_space(): configs = [] @@ -55,26 +72,36 @@ def get_full_tuning_space(): # for split_k in split_k_range: for num_stages in num_stage_range: for waves_per_eu in waves_per_eu_range: - for matrix_instr_nonkdim in matrix_instr_nonkdim_range: + for (matrix_instr_nonkdim + ) in matrix_instr_nonkdim_range: for kpack in kpack_range: - configs.append({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu, 'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack}) + configs.append({ + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_m, + "num_warps": num_warps, + "num_stages": num_stages, + "waves_per_eu": waves_per_eu, + "matrix_instr_nonkdim": + matrix_instr_nonkdim, + "kpack": kpack, + }) return configs + ## Utilize method from rocm/Triton tuning script def prune_configs(M, N, K, configs): pruned_configs = [] - elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes) - elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes) + elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes) + elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes) - if M < 32 or N < 32: - mfma = 16 - else: - mfma = 32 + mfma = 16 if M < 32 or N < 32 else 32 # TODO (zhanglx): figure out the boundary between large and small gemms large_gemm = False - if M >= 2048 and N >=2048: + if M >= 2048 and N >= 2048: large_gemm = True for config in configs: @@ -83,28 +110,29 @@ def prune_configs(M, N, K, configs): BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") num_warps = config.get("num_warps") matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") - kpack = config.get("kpack") + # kpack = config.get("kpack") if matrix_instr_nonkdim > mfma: continue if mfma == 4 and BLOCK_SIZE_K < 64: continue # some layouts could not work properly in case - # number elemens per thread is less 1 + # number elements per thread is less 1 if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: continue - SPLIT_K = 1 #config.get("SPLIT_K") + SPLIT_K = 1 # config.get("SPLIT_K") GROUP_M = config.get("GROUP_SIZE_M") - if BLOCK_SIZE_M < matrix_instr_nonkdim or BLOCK_SIZE_N < 
matrix_instr_nonkdim: + if (matrix_instr_nonkdim > BLOCK_SIZE_M + or matrix_instr_nonkdim > BLOCK_SIZE_N): continue - if M <= matrix_instr_nonkdim and BLOCK_SIZE_M != matrix_instr_nonkdim: + if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M: continue - if N <= matrix_instr_nonkdim and BLOCK_SIZE_N != matrix_instr_nonkdim: + if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N: continue # Skip BLOCK_SIZE that is too large compare to M/N # unless BLOCK_SIZE is already small enough - if BLOCK_SIZE_M > M * 2 and BLOCK_SIZE_M != 16: + if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: continue - if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16: + if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: continue # skip large split_k when not necessary if SPLIT_K != 1 and not need_split_k(M, N, K): @@ -119,7 +147,8 @@ def prune_configs(M, N, K, configs): continue # out of shared memory resource # TODO (zhanglx): This does not consider the LDS usage in the epilogue - LDS = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b + LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) if LDS > 65536: continue # Skip small block sizes and num_warps for large gemm @@ -136,6 +165,7 @@ def prune_configs(M, N, K, configs): return pruned_configs + def union_of_list_of_dicts(l1, l2): result = [] l1.extend(l2) @@ -145,16 +175,18 @@ def union_of_list_of_dicts(l1, l2): return result + def need_split_k(SIZE_M, SIZE_N, SIZE_K): return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + def run_grid(bs, method): d_model = 4096 num_total_experts = 8 top_k = 2 tp_size = TP model_intermediate_size = 14336 - num_layers = 32 + # num_layers = 32 num_calls = 100 num_warmup_trials = 1 @@ -172,7 +204,8 @@ def run_grid(bs, method): prune_configs_2 = prune_configs(M2, N2, K2, full_configs) configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2) - print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | {len(prune_configs_2)=} | {len(configs)=}") + print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \ + {len(prune_configs_2)=} | {len(configs)=}") best_config = None best_time_us = 1e20 @@ -210,7 +243,7 @@ def run_grid(bs, method): ) kernel_dur_us = 1000 * kernel_dur_ms - model_dur_ms = kernel_dur_ms * num_layers + # model_dur_ms = kernel_dur_ms * num_layers if kernel_dur_us < best_time_us: best_config = config @@ -237,9 +270,17 @@ def run_grid(bs, method): f.write("\n") -def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, - top_k: int, tp_size: int, model_intermediate_size: int, method, - config) -> float: +def run_timing( + num_calls: int, + bs: int, + d_model: int, + num_total_experts: int, + top_k: int, + tp_size: int, + model_intermediate_size: int, + method, + config, +) -> float: shard_intermediate_size = model_intermediate_size // tp_size hidden_states = torch.rand( @@ -260,13 +301,15 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, dtype=hidden_states.dtype, ) - gating_output = F.softmax(torch.rand( - # (num_calls, bs, num_total_experts), # THIS - (bs, num_total_experts), - device=hidden_states.device, - dtype=torch.float32, - ), - dim=-1) + gating_output = F.softmax( + torch.rand( + # (num_calls, bs, num_total_experts), # THIS + (bs, num_total_experts), + device=hidden_states.device, + dtype=torch.float32, + ), + dim=-1, + ) ###### Stuff from fused moe ###### From fd0466171c631b05ec5cb154aef2219f4a97e86b Mon Sep 17 00:00:00 2001 From: 
Divakar Verma Date: Thu, 13 Jun 2024 14:29:13 +0000 Subject: [PATCH 4/4] [nit] update __init__ for tuning script --- vllm/model_executor/__init__.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index fb98f4a6b46f4..76523845063ed 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,7 +1,10 @@ -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_random_seed +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_moe, get_config_file_name, invoke_fused_moe_kernel, + moe_align_block_size) __all__ = [ - "SamplingMetadata", - "set_random_seed", + "fused_moe", + "get_config_file_name", + "moe_align_block_size", + "invoke_fused_moe_kernel", ]
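
Note on how these tuned entries are consumed and regenerated:

At load time, vllm/model_executor/layers/fused_moe/fused_moe.py reads the JSON file whose name encodes the expert count, shard intermediate size, and device name (see get_config_file_name) and selects the entry whose batch-size key is nearest to the current number of tokens. The snippet below is a minimal standalone sketch of that lookup; load_moe_config is a hypothetical helper written for illustration, not a vLLM API, and the authoritative selection logic lives in fused_moe.py.

import json


def load_moe_config(path: str, num_tokens: int) -> dict:
    """Return the tuned kernel config whose batch-size key is nearest to num_tokens."""
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    # The files above are keyed by token count (1, 2, 4, ..., 4096); fall back to
    # the closest tuned batch size when the current token count has no exact entry.
    nearest = min(configs, key=lambda k: abs(k - num_tokens))
    return configs[nearest]


if __name__ == "__main__":
    cfg = load_moe_config(
        "vllm/model_executor/layers/fused_moe/configs/"
        "E=8,N=14336,device_name=AMD_Instinct_MI300X.json",
        num_tokens=48,
    )
    print(cfg)  # {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, ...}

To regenerate these files for a different device or shard size, run the tuning script added in patch 2/4, e.g. python benchmarks/kernels/benchmark_mixtral_moe_rocm.py (it pins CUDA_VISIBLE_DEVICES and the TP degree internally). For each batch size listed in main(), the script sweeps the pruned Triton config space from get_full_tuning_space(), times both fused-MoE GEMMs via run_timing(), and writes the best config for that batch size into the file returned by get_config_file_name().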