Start working on fused moe cutlass implementation
Signed-off-by: ElizaWszola <[email protected]>
ElizaWszola committed Jan 17, 2025
1 parent c6231b6 commit f1a5666
Showing 8 changed files with 422 additions and 9 deletions.
196 changes: 190 additions & 6 deletions benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
@@ -6,6 +6,8 @@

from vllm import _custom_ops as ops
from vllm.utils import FlexibleArgumentParser
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_moe, fused_topk, fused_experts)

DEFAULT_MODELS = ["nm-testing/Mixtral-8x7B-Instruct-v0.1"]
# "nm-testing/deepseekv2-lite",
@@ -17,6 +19,11 @@
PER_ACT_TOKEN_OPTS = [False, True]
PER_OUT_CH_OPTS = [False, True]

def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)

def grouped_gemm(a_g_tensors: List[torch.Tensor],
b_g_tensors: List[torch.Tensor],
out_g_tensors: List[torch.Tensor],
@@ -33,6 +40,30 @@ def baseline_gemm(num_groups: int, a_tensors: List[torch.Tensor],
b = b_tensors[g]
out = torch.mm(a, b)
out_tensors[g] = out

def cutlass_fused(a_tensors: List[torch.Tensor],
w1_tensors: List[torch.Tensor],
w2_tensors: List[torch.Tensor],
c1_tensors: List[torch.Tensor],
c2_tensors: List[torch.Tensor],
c2_tensors_fp8: List[torch.Tensor],
c3_tensors: List[torch.Tensor],
a_scales: List[torch.Tensor],
w1_scales: List[torch.Tensor],
w2_scales: List[torch.Tensor],
c2_scales: List[torch.Tensor],
num_groups: int):
# output_dtype = c3_tensors[0].dtype
N = c2_tensors[0].shape[1]
ops.cutlass_grouped_mm(c1_tensors, a_tensors, w1_tensors,
a_scales, w1_scales)
# TODO make this work as it should
for idx in range(num_groups):
torch.ops._C.silu_and_mul(c2_tensors[idx], c1_tensors[idx].view(-1, N))
print(c2_tensors[idx])
c2_tensors_fp8[idx] = to_fp8(c2_tensors[idx].half())
ops.cutlass_grouped_mm(c3_tensors, c2_tensors, w2_tensors,
c2_scales, w2_scales)
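
# --- Illustrative reference, not part of this commit -------------------------
# A plain-PyTorch sketch of the per-expert math that cutlass_fused is meant to
# perform (handy for checking the CUTLASS path once the TODO above is resolved).
# It reuses to_fp8 from this file; the name reference_fused and the exact
# requantization step are assumptions, not vLLM API.
def reference_fused(a_tensors, w1_tensors, w2_tensors, a_scales, w1_scales,
                    w2_scales, c2_scales, num_groups):
    c3_ref = []
    for g in range(num_groups):
        # Dequantize the fp8 inputs with their scales, then run GEMM 1.
        a_f = a_tensors[g].to(torch.float32) * a_scales[g]
        w1_f = w1_tensors[g].to(torch.float32) * w1_scales[g]
        c1 = a_f @ w1_f  # [m_g, 2 * n_g]
        # silu_and_mul: silu(gate) * up, halving the last dimension.
        n = c1.shape[1] // 2
        c2 = torch.nn.functional.silu(c1[:, :n]) * c1[:, n:]
        # Requantize the intermediate activations, then run GEMM 2.
        c2_q = to_fp8(c2 / c2_scales[g])
        w2_f = w2_tensors[g].to(torch.float32) * w2_scales[g]
        c3_ref.append((c2_q.to(torch.float32) * c2_scales[g]) @ w2_f)
    return c3_ref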

def bench_run(results: List[benchmark.Measurement], model: str, num_groups: int,
per_act_token: bool, per_out_ch: bool,
@@ -47,11 +78,6 @@ def bench_run(results: List[benchmark.Measurement], model: str, num_groups: int,

device = "cuda"
out_dtype = torch.half

def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)

a_tensors = []
b_tensors = []
@@ -142,6 +168,164 @@ def to_fp8(tensor: torch.Tensor):
description="baseline_gemm",
).blocked_autorange(min_run_time=min_run_time))

def bench_run_moe(results: List[benchmark.Measurement], model: str, num_groups: int,
per_act_token: bool, per_out_ch: bool,
mkn: List[Tuple[int, int, int]]):
label = "Quant Matmul"

sub_label = ("{}, num_groups={}, per_act_token={} per_out_ch={}, "
"MKN=({})".format(model, num_groups, per_act_token,
per_out_ch, mkn))

print(f"Testing: {sub_label}")

device = "cuda"
out_dtype = torch.bfloat16

def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)

m_tot = sum([elem[0] for elem in mkn])
k_g = mkn[0][1]
n_g = mkn[0][2]

a_tensors = []
w1_tensors = []
w2_tensors = []
c1_tensors = []
c2_tensors = []
c2_tensors_fp8 = []
c3_tensors = []
a_scales = []
w1_scales = []
w2_scales = []
c2_scales = []

a = torch.randn((m_tot, k_g), device=device, dtype=out_dtype)
w1 = torch.randn((num_groups, 2 * n_g, k_g), device=device, dtype=out_dtype)
w2 = torch.randn((num_groups, k_g, n_g), device=device, dtype=out_dtype)
scored_output = torch.randn((m_tot, num_groups), device="cuda", dtype=out_dtype)
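# Shapes, for reference: a is [m_tot, k_g] activations, w1 is
# [num_groups, 2 * n_g, k_g] (the 2 * n_g dimension matches fused_moe's fused
# gate/up projection layout), w2 is [num_groups, k_g, n_g], and scored_output
# is [m_tot, num_groups] router logits consumed by fused_topk below.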

topk = 2
# triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)

#TODO grouped topk for deepseek
topk_weights, topk_ids = fused_topk(a, scored_output, topk, renormalize=True)

fused_experts(a, w1, w2, topk_weights, topk_ids)
topk_ids_cpu = topk_ids.cpu()

occurrences = [0] * num_groups
expert_offsets = [0] * (num_groups + 1)
for id in topk_ids_cpu.flatten():
occurrences[id] += 1

for e in range(num_groups):
expert_offsets[e + 1] = expert_offsets[e] + occurrences[e]

print(expert_offsets, m_tot)
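
# Illustrative example, not part of this commit: with num_groups=4, topk=2 and
# topk_ids = [[0, 2], [1, 0], [3, 2]], the flattened ids are [0, 2, 1, 0, 3, 2],
# so occurrences == [2, 1, 2, 1] and expert_offsets == [0, 2, 3, 5, 6].
# expert_offsets[e]:expert_offsets[e + 1] is then the row range owned by expert
# e once the topk-expanded activations are grouped by expert id.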

a = torch.randn((m_tot, k_g))
a_group[0] = a[sorted_token_ids[0]]


# TODO
# create full input tensor m_tot x k_g x topk
# get shuffle data like sorted_token_ids etc.
# create view

for g in range(num_groups):
m_g = occurrences[g]
a_g = to_fp8(torch.randn((m_g, k_g), device=device))
w1_g = to_fp8(torch.randn((2 * n_g, k_g), device=device).t())
w2_g = to_fp8(torch.randn((k_g, n_g), device=device).t())
c1_g = torch.zeros((m_g, 2 * n_g), device=device, dtype=torch.bfloat16)
c2_g = torch.zeros((m_g, n_g), device=device, dtype=torch.bfloat16)
c2_g_fp8 = to_fp8(torch.zeros((m_g, n_g), device=device))
c3_g = torch.zeros((m_g, k_g), device=device, dtype=torch.bfloat16)
# m_a_scales = m_g if per_act_token else 1
# n_b_scales = n_g if per_out_ch else 1
m_scales = 1
n2_scales = 1
k_scales = 1
scale_a = (torch.randn((m_scales, 1), device=device,
dtype=torch.float32))
scale_w1 = (torch.randn((n2_scales, 1), device=device,
dtype=torch.float32))
scale_w2 = (torch.randn((k_scales, 1), device=device,
dtype=torch.float32))
scale_c2 = (torch.randn((m_scales, 1), device=device,
dtype=torch.float32))

a_tensors.append(a_g)
w1_tensors.append(w1_g)
w2_tensors.append(w2_g)
c1_tensors.append(c1_g)
c2_tensors.append(c2_g)
c2_tensors_fp8.append(c2_g_fp8)
c3_tensors.append(c3_g)
a_scales.append(scale_a)
w1_scales.append(scale_w1)
w2_scales.append(scale_w2)
c2_scales.append(scale_c2)

globals = {
# Gen params
"num_groups": num_groups,
# Grouped gemm params
"a_tensors": a_tensors,
"w1_tensors": w1_tensors,
"w2_tensors": w2_tensors,
"c1_tensors": c1_tensors,
"c2_tensors": c2_tensors,
"c2_tensors_fp8": c2_tensors_fp8,
"c3_tensors": c3_tensors,
"a_scales": a_scales,
"w1_scales": w1_scales,
"w2_scales": w2_scales,
"c2_scales": c2_scales,
# Triton params (fused_moe)
"a": a,
"w1": w1,
"w2": w2,
"scored_output": scored_output,
"topk": topk,
# Kernels
"fused_moe": fused_moe,
"cutlass_fused": cutlass_fused,
}

min_run_time = 1
num_warmup = 5

# Warmup triton
for _ in range(num_warmup):
fused_moe(a, w1, w2, scored_output, topk, renormalize=False)

results.append(
benchmark.Timer(
stmt="fused_moe(a, w1, w2, scored_output, topk, renormalize=False)",
globals=globals,
label=label,
sub_label=sub_label,
description="grouped_gemm",
).blocked_autorange(min_run_time=min_run_time))

# Warmup cutlass
for _ in range(num_warmup):
cutlass_fused(a_tensors, w1_tensors, w2_tensors, c1_tensors, c2_tensors,
c2_tensors_fp8, c3_tensors, a_scales, w1_scales,
w2_scales, c2_scales, num_groups)

results.append(
benchmark.Timer(
stmt=
"cutlass_fused(a_tensors, w1_tensors, w2_tensors, c1_tensors, c2_tensors, c2_tensors_fp8, c3_tensors, a_scales, w1_scales, w2_scales, c2_scales, num_groups)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="baseline_gemm",
).blocked_autorange(min_run_time=min_run_time))

def main(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
@@ -165,7 +349,7 @@ def main(args):
for per_out_ch in PER_OUT_CH_OPTS:
for size_m in DEFAULT_BATCH_SIZES:
mkn = [(size_m, size_k, size_n)] * num_groups
bench_run(results, model, num_groups, per_act_token,
bench_run_moe(results, model, num_groups, per_act_token,
per_out_ch, mkn)

compare = benchmark.Compare(results)
8 changes: 8 additions & 0 deletions csrc/cpu/torch_bindings.cpp
@@ -125,6 +125,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor out_offsets, Tensor a_offsets, "
" Tensor b_offsets) -> ()");
ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm);

ops.def(
"compute_expert_offsets(Tensor! trg_a_ptrs,"
" Tensor! a, Tensor topk_ids,"
" Tensor! expert_offsets, SymInt num_experts) -> ()");
ops.impl("compute_expert_offsets", torch::kCUDA,
&compute_expert_offsets);

// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
@@ -424,9 +424,9 @@ struct Sm90ColOrScalarBroadcast {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;

if (threadIdx.x ==128){
printf("M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l);
}
// if (threadIdx.x ==128){
// printf("M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l);
// }
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
6 changes: 6 additions & 0 deletions csrc/ops.h
@@ -161,6 +161,12 @@ void cutlass_grouped_mm(c10::List<at::Tensor> const& out_tensors,
c10::List<at::Tensor> const& a_scales,
c10::List<at::Tensor> const& b_scales);

void compute_expert_offsets(torch::Tensor& trg_a_ptrs,
torch::Tensor& a,
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
const int64_t num_experts);

void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
47 changes: 47 additions & 0 deletions csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu
@@ -321,3 +321,50 @@ void cutlass_grouped_mm_sm90(c10::List<at::Tensor> const& out_tensors,
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales);
}

__global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs,
cutlass::float_e4m3_t* base_a_ptr,
const int* __restrict__ topk_ids,
int64_t* expert_offsets,
int topk_length) {
int expert_id = threadIdx.x;
int num_experts = blockDim.x;

int occurrences = 0;
for (int i = 0; i < topk_length; ++i) {
occurrences += (topk_ids[i] == expert_id);
}
expert_offsets[expert_id + 1] = occurrences;
__syncthreads();

if (threadIdx.x == 0) {
int64_t tot_offset = 0;
expert_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) {
trg_a_ptrs[i] = base_a_ptr + tot_offset;
tot_offset += expert_offsets[i + 1];
expert_offsets[i + 1] = tot_offset;
}
}
}

// For a given "a" of size [M,K] performs a permutation of the M rows based
// on the given "perm" indices.
__global__ void permute_rows_kernel(cutlass::float_e4m3_t const* __restrict__ a_ptr,
int const* __restrict__ perm_int_ptr,
cutlass::float_e4m3_t* __restrict__ out_ptr,
int size_m, int size_k, int block_rows) {
// TODO
}

void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs,
torch::Tensor& a,
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
const int64_t num_experts) {
get_a_expert_offsets<<<1, num_experts>>>((float_e4m3_t**)trg_a_ptrs.data_ptr(),
(cutlass::float_e4m3_t*)a.data_ptr(),
(const int*)topk_ids.data_ptr(),
(int64_t*)expert_offsets.data_ptr(),
topk_ids.numel());
}
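
For context, a rough host-side PyTorch equivalent of what get_a_expert_offsets computes (an illustrative sketch, not part of the commit; the helper name expert_offsets_reference is made up):

import torch

def expert_offsets_reference(topk_ids: torch.Tensor,
                             num_experts: int) -> torch.Tensor:
    # Count how many (token, expert) assignments land on each expert, then
    # prefix-sum so offsets[e]:offsets[e + 1] is expert e's row range after
    # the activations are grouped by expert id.
    occurrences = torch.bincount(topk_ids.flatten().cpu(),
                                 minlength=num_experts)
    offsets = torch.zeros(num_experts + 1, dtype=torch.int64)
    offsets[1:] = torch.cumsum(occurrences, dim=0)
    return offsets

The kernel performs the same count with one thread per expert; thread 0 then accumulates the counts into expert_offsets and fills trg_a_ptrs[e] with base_a_ptr + offsets[e], so the grouped GEMM can address each expert's slice of the activation buffer directly. Once built, the op registered below should presumably be reachable as torch.ops._C.compute_expert_offsets(trg_a_ptrs, a, topk_ids, expert_offsets, num_experts), per its schema in csrc/torch_bindings.cpp.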
16 changes: 16 additions & 0 deletions csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -36,6 +36,13 @@ void cutlass_grouped_mm_sm90(c10::List<at::Tensor> const& out_tensors,
c10::List<at::Tensor> const& a_scales,
c10::List<at::Tensor> const& b_scales);


void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs,
torch::Tensor& a,
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
const int64_t num_experts);

#endif

void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
@@ -159,6 +166,15 @@ void cutlass_grouped_mm(c10::List<at::Tensor> const& out_tensors,
b_scales);
}

void compute_expert_offsets(torch::Tensor& trg_a_ptrs,
torch::Tensor& a,
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
const int64_t num_experts) {
compute_expert_offsets_caller(trg_a_ptrs, a, topk_ids, expert_offsets,
num_experts);
}

void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
7 changes: 7 additions & 0 deletions csrc/torch_bindings.cpp
@@ -329,6 +329,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor[] b_scales) -> ()");
ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm);

ops.def(
"compute_expert_offsets(Tensor! trg_a_ptrs,"
" Tensor! a, Tensor topk_ids,"
" Tensor! expert_offsets, SymInt num_experts) -> ()");
ops.impl("compute_expert_offsets", torch::kCUDA,
&compute_expert_offsets);

// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
// given capability
ops.def(