
Commit 265e99a
update intermediate_size name
dsikka committed Jan 18, 2025
1 parent c2bce52 commit 265e99a
Showing 7 changed files with 132 additions and 108 deletions.
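Note on the rename: the intermediate_size argument of create_weights was already the per-tensor-parallel-rank slice of each expert's MLP width (FusedMoE passes self.intermediate_size_per_partition for it, as the first file below shows), so the new name intermediate_size_per_partition only makes that explicit. A minimal sketch of the relationship, with illustrative values and a hypothetical helper rather than code from this commit:

# Illustrative sketch only; split_intermediate_size and the sizes are made up.
def split_intermediate_size(intermediate_size: int, tp_size: int) -> int:
    # Each tensor-parallel rank owns an equal slice of the expert MLP width.
    assert intermediate_size % tp_size == 0
    return intermediate_size // tp_size

intermediate_size = 14336   # full width of one expert (example value)
tp_size = 4                 # tensor-parallel world size (example value)
intermediate_size_per_partition = split_intermediate_size(intermediate_size, tp_size)
print(intermediate_size_per_partition)  # 3584
# create_weights() then allocates, per expert:
#   w13: (num_experts, 2 * intermediate_size_per_partition, hidden_size)
#   w2:  (num_experts, hidden_size, intermediate_size_per_partition)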
31 changes: 17 additions & 14 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -38,7 +38,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 
     @abstractmethod
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         raise NotImplementedError
 
@@ -65,22 +65,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
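For context, a simplified sketch of how one expert consumes the tensors allocated above (plain PyTorch, not vLLM's fused kernel; sizes are examples): w13 stacks the gate (w1) and up (w3) projections along the output dimension, and w2 projects the activated product back to hidden_size.

import torch

def expert_mlp(x, w13, w2):
    # x: (tokens, hidden), w13: (2 * I_p, hidden), w2: (hidden, I_p)
    gate_up = x @ w13.t()                # (tokens, 2 * I_p)
    gate, up = gate_up.chunk(2, dim=-1)  # each (tokens, I_p)
    return (torch.nn.functional.silu(gate) * up) @ w2.t()  # (tokens, hidden)

I_p, hidden = 3584, 4096                 # example per-partition / hidden sizes
x = torch.randn(2, hidden)
w13 = torch.randn(2 * I_p, hidden)
w2 = torch.randn(hidden, I_p)
print(expert_mlp(x, w13, w2).shape)      # torch.Size([2, 4096])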
@@ -292,14 +294,15 @@ def __init__(
         moe_quant_params = {
             "num_experts": num_experts,
             "hidden_size": hidden_size,
-            "intermediate_size": self.intermediate_size_per_partition,
+            "intermediate_size_per_partition":
+            self.intermediate_size_per_partition,
             "params_dtype": params_dtype,
             "weight_loader": self.weight_loader,
         }
         # need full intermediate size pre-sharding for WNA16 act order
         if (self.quant_method.__class__.__name__ ==
                 "CompressedTensorsWNA16MoEMethod"):
-            moe_quant_params["intermediate_full"] = intermediate_size
+            moe_quant_params["intermediate_size_full"] = intermediate_size
 
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
@@ -438,15 +441,15 @@ def weight_loader(self, param: torch.nn.Parameter,
         ]
         # Fetch the dim to shard the parameter/loaded weight
         # based on the shard id. This will be whatever
-        # dimension intermediate_size is used.
+        # dimension intermediate_size_per_partition is used.
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
 
         expert_data = param.data[expert_id]
         tp_rank = get_tensor_model_parallel_rank()
 
         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
-        # should be whatever dimension intermediate_size is
+        # should be whatever dimension intermediate_size_per_partition is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
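A rough sketch of the sharding rule those comments describe (hypothetical helper, not the actual weight_loader): w1/w3 checkpoints are sliced along their output (intermediate) dimension, w2 along its input (intermediate) dimension, and the dimension flips when the checkpoint stores transposed weights.

import torch

SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

def shard_loaded_weight(loaded: torch.Tensor, shard_id: str, tp_rank: int,
                        shard_size: int, is_transposed: bool) -> torch.Tensor:
    # Slice the full checkpoint tensor down to this rank's partition.
    dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
    if is_transposed:
        dim = 1 - dim  # flip between the two dims of the 2-D expert weight
    return loaded.narrow(dim, tp_rank * shard_size, shard_size)

full_w2 = torch.randn(4096, 14336)  # (hidden_size, full intermediate_size)
part = shard_loaded_weight(full_w2, "w2", tp_rank=1, shard_size=3584,
                           is_transposed=False)
print(part.shape)  # torch.Size([4096, 3584])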
35 changes: 19 additions & 16 deletions vllm/model_executor/layers/quantization/awq_marlin.py
@@ -303,7 +303,7 @@ def __init__(self, quant_config: AWQMarlinConfig):
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         extra_weight_attrs.update({
             "is_transposed":
@@ -312,17 +312,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             FusedMoeWeightScaleSupported.GROUP.value,
         })
 
-        w13_qweight = Parameter(torch.empty(num_experts,
-                                            hidden_size,
-                                            2 * intermediate_size //
-                                            self.quant_config.pack_factor,
-                                            dtype=torch.int32),
-                                requires_grad=False)
+        w13_qweight = Parameter(
+            torch.empty(num_experts,
+                        hidden_size,
+                        2 * intermediate_size_per_partition //
+                        self.quant_config.pack_factor,
+                        dtype=torch.int32),
+            requires_grad=False)
         layer.register_parameter("w13_qweight", w13_qweight)
         set_weight_attrs(w13_qweight, extra_weight_attrs)
 
         w2_qweight = Parameter(torch.empty(num_experts,
-                                           intermediate_size,
+                                           intermediate_size_per_partition,
                                            hidden_size //
                                            self.quant_config.pack_factor,
                                            dtype=torch.int32),
@@ -331,13 +332,14 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         set_weight_attrs(w2_qweight, extra_weight_attrs)
 
         num_groups_w13 = hidden_size // self.quant_config.group_size
-        num_groups_w2 = intermediate_size // self.quant_config.group_size
+        num_groups_w2 = (intermediate_size_per_partition //
+                         self.quant_config.group_size)
 
         # WEIGHT_SCALES
         # Allocate 2 scales for w1 and w3 respectively.
         w13_scales = Parameter(torch.empty(num_experts,
                                            num_groups_w13,
-                                           intermediate_size * 2,
+                                           intermediate_size_per_partition * 2,
                                            dtype=params_dtype),
                                requires_grad=False)
         layer.register_parameter("w13_scales", w13_scales)
@@ -353,12 +355,13 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
 
         # WEIGHT_ZERO_POINT
         # Allocate 2 zero points for w1 and w3 respectively.
-        w13_qzeros = Parameter(torch.empty(num_experts,
-                                           num_groups_w13,
-                                           2 * intermediate_size //
-                                           self.quant_config.pack_factor,
-                                           dtype=torch.int32),
-                               requires_grad=False)
+        w13_qzeros = Parameter(
+            torch.empty(num_experts,
+                        num_groups_w13,
+                        2 * intermediate_size_per_partition //
+                        self.quant_config.pack_factor,
+                        dtype=torch.int32),
+            requires_grad=False)
         layer.register_parameter("w13_qzeros", w13_qzeros)
         set_weight_attrs(w13_qzeros, extra_weight_attrs)
 
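The divisions by pack_factor above come from bit packing: several low-bit weights share one int32 word. Illustrative shape arithmetic only, assuming 4-bit AWQ weights so that pack_factor = 32 // 4 = 8 (the sizes are example values, not from this commit):

num_bits = 4                               # assumed AWQ weight width
pack_factor = 32 // num_bits               # 8 int4 values per int32 word
num_experts, hidden_size = 8, 4096         # example values
intermediate_size_per_partition = 3584     # example value (per TP rank)

w13_qweight_shape = (num_experts, hidden_size,
                     2 * intermediate_size_per_partition // pack_factor)
w2_qweight_shape = (num_experts, intermediate_size_per_partition,
                    hidden_size // pack_factor)
print(w13_qweight_shape)  # (8, 4096, 896)
print(w2_qweight_shape)   # (8, 3584, 512)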
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -76,24 +76,26 @@ def __init__(
         self.static_input_scales = not self.input_quant.dynamic
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
 
         params_dtype = torch.float8_e4m3fn
 
         # WEIGHTS
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
@@ -268,47 +270,49 @@ def __init__(
                          f"{WNA16_SUPPORTED_BITS}")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
 
         assert params_dtype == torch.float16, (
             "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
         )
 
+        intermediate_size_full = extra_weight_attrs.pop(
+            "intermediate_size_full")
+
         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
         # shard for TP along the transposed dims
-        intermediate_full = extra_weight_attrs.pop("intermediate_full")
         extra_weight_attrs.update({
             "is_transposed": True,
             "quant_method": self.strategy
         })
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    hidden_size //
-                                                    self.packed_factor,
-                                                    2 * intermediate_size,
-                                                    dtype=torch.int32),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size // self.packed_factor,
+            2 * intermediate_size_per_partition,
+            dtype=torch.int32),
                                         requires_grad=False)
         layer.register_parameter("w13_weight_packed", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   intermediate_size //
-                                                   self.packed_factor,
-                                                   hidden_size,
-                                                   dtype=torch.int32),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            intermediate_size_per_partition // self.packed_factor,
+            hidden_size,
+            dtype=torch.int32),
                                        requires_grad=False)
         layer.register_parameter("w2_weight_packed", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # In the case where we have actorder/g_idx,
         # we do not partition the w2 scales
         load_full_w2 = self.actorder and self.group_size != -1
-        w2_scales_size = (intermediate_full
-                          if load_full_w2 else intermediate_size)
+        w2_scales_size = (intermediate_size_full
+                          if load_full_w2 else intermediate_size_per_partition)
 
-        self.is_k_full = (not self.actorder) or (intermediate_size
-                                                 == intermediate_full)
+        self.is_k_full = (not self.actorder) or (
+            intermediate_size_per_partition == intermediate_size_full)
 
         if self.strategy == "channel":
             num_groups_w2 = num_groups_w13 = 1
@@ -317,10 +321,11 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             num_groups_w2 = w2_scales_size // self.group_size
             num_groups_w13 = hidden_size // self.group_size
 
-        w13_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                  num_groups_w13,
-                                                  2 * intermediate_size,
-                                                  dtype=params_dtype),
+        w13_scale = torch.nn.Parameter(torch.ones(
+            num_experts,
+            num_groups_w13,
+            2 * intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w13_weight_scale", w13_scale)
         set_weight_attrs(w13_scale, extra_weight_attrs)
@@ -358,7 +363,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         w2_g_idx = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size,
+                intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -381,7 +386,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         w2_g_idx_sort_indices = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size,
+                intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
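The intermediate_size_full bookkeeping above exists because, with activation reordering (actorder/g_idx) and a finite group size, the w2 scales keep the full pre-sharding width, and is_k_full tells the kernel whether this rank sees the whole reduction dimension. A standalone sketch of that logic, using a hypothetical helper that mirrors the conditions in the diff:

def wna16_moe_sizes(intermediate_size_full: int, tp_size: int,
                    actorder: bool, group_size: int):
    # Per-rank slice of the expert MLP width.
    intermediate_size_per_partition = intermediate_size_full // tp_size
    # Do not partition w2 scales when act-order is used with grouped scales.
    load_full_w2 = actorder and group_size != -1
    w2_scales_size = (intermediate_size_full
                      if load_full_w2 else intermediate_size_per_partition)
    is_k_full = (not actorder) or (
        intermediate_size_per_partition == intermediate_size_full)
    return w2_scales_size, is_k_full

print(wna16_moe_sizes(14336, tp_size=4, actorder=True, group_size=128))
# (14336, False): full-width scales, this rank sees only part of K
print(wna16_moe_sizes(14336, tp_size=1, actorder=True, group_size=128))
# (14336, True): a single rank sees the full K dimension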
27 changes: 15 additions & 12 deletions vllm/model_executor/layers/quantization/experts_int8.py
@@ -52,7 +52,7 @@ def __init__(self, quant_config: ExpertsInt8Config):
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
 
         int8_dtype = torch.int8
@@ -64,26 +64,29 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         extra_weight_attrs['weight_loader'] = wrapped_weight_loader
 
         # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=int8_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=int8_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=int8_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=int8_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
-        w13_scale = torch.nn.Parameter(torch.zeros(num_experts,
-                                                   2 * intermediate_size,
-                                                   dtype=torch.float32),
+        w13_scale = torch.nn.Parameter(torch.zeros(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            dtype=torch.float32),
                                        requires_grad=False)
         layer.register_parameter("w13_scale", w13_scale)
 
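For context, w13_scale above stores one float32 scale per output channel of the fused gate/up projection. A simplified sketch of symmetric per-channel int8 quantization of that flavor (illustrative only, not necessarily the exact ExpertsInt8 scheme):

import torch

def quantize_per_channel(w: torch.Tensor):
    # w: (out_features, in_features) float weight; one scale per output row.
    scale = w.abs().amax(dim=1).clamp(min=1e-8) / 127.0
    q = torch.round(w / scale[:, None]).clamp(-128, 127).to(torch.int8)
    return q, scale.to(torch.float32)

w13 = torch.randn(2 * 3584, 4096)  # (2 * intermediate_per_partition, hidden)
q, scale = quantize_per_channel(w13)
dequant = q.to(torch.float32) * scale[:, None]   # reference dequantization
print(q.shape, scale.shape)  # torch.Size([7168, 4096]) torch.Size([7168])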

0 comments on commit 265e99a
