enable fused topK_softmax kernel for hip #14

Merged · 4 commits · May 16, 2024
Changes from all commits
2 changes: 2 additions & 0 deletions csrc/cuda_compat.h
@@ -18,8 +18,10 @@

#ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
#else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) __shfl_xor(var, lane_mask, width)
#endif

#ifndef USE_ROCM
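The new VLLM_SHFL_XOR_SYNC_WIDTH macro maps to __shfl_xor_sync on CUDA and __shfl_xor on ROCm, so the width-limited butterfly reductions in the top-k softmax kernel (the "for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)" loops below) compile on both platforms. As a rough illustration of what such a reduction computes, here is a minimal Python sketch that simulates an XOR-butterfly max over groups of `width` lanes; the function name and lane values are hypothetical and only meant to show the data movement, not the kernel itself.

def butterfly_max(lanes, width):
    # Every lane in each group of `width` lanes ends up holding the group's
    # maximum, mirroring the XOR-shuffle reduction pattern in the kernel.
    out = list(lanes)
    mask = width // 2
    while mask > 0:
        out = [max(out[i], out[i ^ mask]) for i in range(len(out))]
        mask //= 2
    return out

print(butterfly_max([3.0, 1.0, 7.0, 2.0, 5.0, 9.0, 0.0, 4.0], width=4))
# -> [7.0, 7.0, 7.0, 7.0, 9.0, 9.0, 9.0, 9.0]
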
27 changes: 17 additions & 10 deletions csrc/moe/topk_softmax_kernels.cu
@@ -19,15 +19,22 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "csrc/cuda_compat.h"

#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

namespace vllm {
namespace moe {

static constexpr int WARP_SIZE = 32;

/// Aligned array type
template <
typename T,
@@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
}

// From this point, thread max in all the threads have the max within the row.
@@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
}

// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
@@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);

// We want lower indices to "win" in every thread so we break ties this way
if (other_max > max_val || (other_max == max_val && other_expert < expert))
@@ -383,7 +390,7 @@ struct TopkConstants
{
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
@@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
{
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;

static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
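Note on the MAX/MIN macros above: they replace std::max/std::min in these compile-time expressions, presumably to keep the constants computable under hipcc as well. For a concrete sense of the values involved, here is a small Python sketch mirroring the TopkConstants / topkGatingSoftmaxLauncherHelper arithmetic for a hypothetical EXPERTS = 8 (WARP_SIZE = 32 matches the constant defined near the top of this file).

# Illustrative only: mirror the compile-time arithmetic for EXPERTS = 8.
EXPERTS = 8                 # hypothetical number of experts
WARP_SIZE = 32              # constant defined at the top of the kernel file
MAX_BYTES_PER_LDG = 16
FLOAT_BYTES = 4             # sizeof(float)

BYTES_PER_LDG = min(MAX_BYTES_PER_LDG, FLOAT_BYTES * EXPERTS)      # 16
ELTS_PER_LDG = BYTES_PER_LDG // FLOAT_BYTES                        # 4 floats per vector load
VECS_PER_THREAD = max(1, EXPERTS // (ELTS_PER_LDG * WARP_SIZE))    # 1
VPT = VECS_PER_THREAD * ELTS_PER_LDG                               # 4 elements per thread
THREADS_PER_ROW = EXPERTS // VPT                                   # 2 threads cooperate per row
ROWS_PER_WARP = WARP_SIZE // THREADS_PER_ROW                       # 16 rows handled per warp

print(BYTES_PER_LDG, VPT, THREADS_PER_ROW, ROWS_PER_WARP)          # 16 4 2 16
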
8 changes: 3 additions & 5 deletions setup.py
@@ -362,15 +362,13 @@ def _read_requirements(filename: str) -> List[str]:

ext_modules = []

if _is_cuda():
ext_modules.append(CMakeExtension(name="vllm._moe_C"))

if _install_punica():
ext_modules.append(CMakeExtension(name="vllm._punica_C"))
if _is_cuda() and _install_punica():
ext_modules.append(CMakeExtension(name="vllm._punica_C"))

if not _is_neuron():
ext_modules.append(CMakeExtension(name="vllm._C"))
ext_modules.append(CMakeExtension(name="vllm._custom_C"))
ext_modules.append(CMakeExtension(name="vllm._moe_C"))

package_data = {
"vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
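Net effect of this setup.py change: vllm._moe_C is now built for every non-Neuron target (so ROCm builds get the fused MoE kernels too), while the Punica extension remains CUDA-only. A minimal sketch of the resulting selection logic, with the environment probes (_is_cuda(), _is_neuron(), _install_punica()) replaced by hypothetical boolean parameters:

def select_extensions(is_cuda: bool, is_neuron: bool, install_punica: bool):
    # Mirrors the post-PR registration: punica stays CUDA-only; the core,
    # custom, and MoE extensions build for any non-Neuron target.
    ext_modules = []
    if is_cuda and install_punica:
        ext_modules.append("vllm._punica_C")
    if not is_neuron:
        ext_modules += ["vllm._C", "vllm._custom_C", "vllm._moe_C"]
    return ext_modules

print(select_extensions(is_cuda=False, is_neuron=False, install_punica=False))
# -> ['vllm._C', 'vllm._custom_C', 'vllm._moe_C']   (e.g. a ROCm build)
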
195 changes: 115 additions & 80 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -8,9 +8,9 @@
import triton
import triton.language as tl

import vllm._moe_C as moe_kernels
from vllm._C import ops
from vllm.logger import init_logger
from vllm.utils import is_hip

logger = init_logger(__name__)

@@ -108,8 +108,8 @@ def fused_moe_kernel(
offs_k[None, :] * stride_ak)

off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
b_ptrs = (b_ptr + off_experts * stride_be +
(offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))

# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
@@ -121,10 +121,12 @@
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
a = tl.load(
a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
@@ -144,8 +146,8 @@
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_ptrs = (c_ptr + stride_cm * offs_token[:, None] +
stride_cn * offs_cn[None, :])
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)

@@ -193,31 +195,46 @@ def moe_align_block_size(
sorted_ids = torch.empty(
(topk_ids.numel() + num_experts * (block_size - 1), ),
dtype=torch.int32,
device=topk_ids.device)
expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
dtype=torch.int32,
device=topk_ids.device)
device=topk_ids.device,
)
expert_ids = torch.empty(
(topk_ids.numel() + num_experts, ),
dtype=torch.int32,
device=topk_ids.device,
)
sorted_ids.fill_(topk_ids.numel())
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
return sorted_ids, expert_ids, num_tokens_post_pad


def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int,
config: Dict[str, Any]) -> None:
def invoke_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: Dict[str, Any],
) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1

grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
"BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), )

fused_moe_kernel[grid](
A,
@@ -310,8 +327,8 @@ def fused_moe(
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert (hidden_states.shape[0] == gating_output.shape[0]
), "Number of tokens mismatch"
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
@@ -323,34 +340,26 @@
M, _ = hidden_states.shape
E, N, _ = w1.shape

if is_hip():
# The MoE kernels are not yet supported on ROCm.
routing_weights = torch.softmax(gating_output,
dim=-1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
else:
import vllm._moe_C as moe_kernels

topk_weights = torch.empty(M,
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk_weights = torch.empty(M,
topk,
dtype=torch.int32,
dtype=torch.float32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.

if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

@@ -367,48 +376,74 @@
else:
# Else use the default config
config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}

if M <= E:
config = {
'BLOCK_SIZE_M': 16,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 1
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}

intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache1 = torch.empty(
(M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache3 = torch.empty(
(M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)

sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E)
topk_ids, config["BLOCK_SIZE_M"], E)

invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
topk_weights, topk_ids, sorted_token_ids,
expert_ids, num_tokens_post_padded, False,
topk_ids.shape[1], config)
invoke_fused_moe_kernel(
hidden_states,
w1,
intermediate_cache1,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
topk_weights, topk_ids, sorted_token_ids,
expert_ids, num_tokens_post_padded, True, 1,
config)
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
)

if inplace:
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states)
return torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states,
)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
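
With the is_hip() torch fallback removed, fused_moe now routes tokens through the fused topk_softmax kernel on both CUDA and ROCm. For reference, the routing it replaces is the eager PyTorch code deleted above: a softmax over the gating logits followed by a top-k. A minimal sketch of that reference (hypothetical shapes; the fused kernel additionally produces int32 ids and a token_expert_indicies buffer):

import torch

def reference_topk_softmax(gating_output: torch.Tensor, topk: int):
    # Eager-mode equivalent of the routing computed by moe_kernels.topk_softmax:
    # softmax over experts, then the top-k experts per token.
    routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
    return topk_weights, topk_ids

gating_output = torch.randn(4, 8)    # 4 tokens, 8 experts (hypothetical)
weights, ids = reference_topk_softmax(gating_output, topk=2)
print(weights.shape, ids.shape)      # torch.Size([4, 2]) torch.Size([4, 2])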