Skip to content

Commit

Permalink
Implement custom moe op for Mixtral
Browse files Browse the repository at this point in the history
  • Loading branch information
jbyczkow committed Sep 24, 2024
1 parent 84b2490 commit 5a271c8
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 11 deletions.
48 changes: 42 additions & 6 deletions vllm/hpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,16 @@ def forward(self, state):
return torch.matmul(state, self.weight)


def calculate_routing_tensors(score, topk, hidden_states_dtype):
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states_dtype)
return routing_weights, selected_experts


class StaticFusedMOE(torch.nn.Module):

def __init__(self, num_total_experts):
Expand All @@ -263,12 +273,8 @@ def __init__(self, num_total_experts):

def forward(self, hidden_states, w1, w2, score, topk):
B, D = hidden_states.shape
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
routing_weights, selected_experts = calculate_routing_tensors(
score, topk, hidden_states.dtype)
final_hidden_states = torch.zeros((1, B, D),
dtype=hidden_states.dtype,
device=hidden_states.device)
Expand All @@ -291,3 +297,33 @@ def forward(self, hidden_states, w1, w2, score, topk):
final_hidden_states += current_hidden_states_static

return final_hidden_states.view(-1, D)


class DynamicFusedMOE(torch.nn.Module):

def __init__(self, num_total_experts):
super().__init__()
self.num_total_experts = num_total_experts

def forward(self, hidden_states, w1, w2, score, topk):
htorch.core.mark_step()
routing_weights, selected_experts = calculate_routing_tensors(
score, topk, hidden_states.dtype)
# pre-processing for custom op inputs
experts_range = range(self.num_total_experts)
w1_list = [w1[i,:,:].squeeze() for i in experts_range]
w2_list = [w2[i,:,:].squeeze() for i in experts_range]

final_hidden_states = torch.ops.hpu.mixture_of_experts(
hidden_states=hidden_states,
expert_routing_table=selected_experts,
router_weights=routing_weights,
w12=w1_list,
w3=w2_list,
permuted_weights=True,
activation="silu",
experts_min=0,
experts_max=7
)

return final_hidden_states.view(-1, hidden_states.shape[1])
17 changes: 12 additions & 5 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,14 @@ def __init__(
self.num_expert_group = num_expert_group
self.topk_group = topk_group
if is_hpu():
from vllm.hpu.ops import StaticFusedMOE
self.hpu_static_fused_moe = StaticFusedMOE(self.num_experts)
from vllm.hpu.ops import StaticFusedMOE, DynamicFusedMOE
from vllm.model_executor.layers.quantization.inc import INCConfig
selected_fused_moe = (
StaticFusedMOE
if isinstance(quant_config, INCConfig)
else DynamicFusedMOE
)
self.hpu_static_fused_moe = selected_fused_moe(self.num_experts)

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
Expand Down Expand Up @@ -254,24 +260,25 @@ def weight_loader(self, param: torch.nn.Parameter,
shard_size = self.intermediate_size_per_partition
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

from vllm.hpu.ops import StaticFusedMOE
# w1, gate_proj case: Load into first shard of w13.
if shard_id == 0:
param_data[expert_id,
0:shard_size, :] = loaded_weight[shard, :]
if is_hpu():
if is_hpu() and isinstance(self.hpu_static_fused_moe, StaticFusedMOE):
self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
param_data[expert_id])
# w3, up_proj case: Load into second shard of w13.
elif shard_id == 2:
param_data[expert_id, shard_size:2 *
shard_size, :] = loaded_weight[shard, :]
if is_hpu():
if is_hpu() and isinstance(self.hpu_static_fused_moe, StaticFusedMOE):
self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
param_data[expert_id])
# w2, down_proj case: Load into only shard of w2.
elif shard_id == 1:
param_data[expert_id, :, :] = loaded_weight[:, shard]
if is_hpu():
if is_hpu() and isinstance(self.hpu_static_fused_moe, StaticFusedMOE):
self.hpu_static_fused_moe.w2_list[expert_id].set_weight(
param_data[expert_id])
else:
Expand Down

0 comments on commit 5a271c8

Please sign in to comment.