Commit 6747657

Actually requantizing moe weights
gshtras committed Jan 15, 2025
1 parent fc59c22 commit 6747657
Showing 1 changed file with 23 additions and 0 deletions.

vllm/model_executor/layers/quantization/fp8.py
@@ -521,6 +521,29 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            w13_weight, w13_weight_scale_inv, w13_input_scale = \
+                normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w13_weight, layer.w13_weight_scale_inv,
+                    layer.w13_input_scale)
+            w2_weight, w2_weight_scale_inv, w2_input_scale = \
+                normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w2_weight, layer.w2_weight_scale_inv,
+                    layer.w2_input_scale)
+            # Reset the parameter
+            layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                  requires_grad=False)
+            layer.w13_weight_scale_inv = torch.nn.Parameter(
+                w13_weight_scale_inv, requires_grad=False)
+            if w13_input_scale is not None:
+                layer.w13_input_scale = torch.nn.Parameter(
+                    w13_input_scale, requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                 requires_grad=False)
+            layer.w2_weight_scale_inv = torch.nn.Parameter(w2_weight_scale_inv,
+                                                           requires_grad=False)
+            if w2_input_scale is not None:
+                layer.w2_input_scale = torch.nn.Parameter(
+                    w2_input_scale, requires_grad=False)
             return
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
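
What the change does: FP8 checkpoints are serialized in the OCP e4m3fn format, while AMD GPUs such as MI300 use the e4m3fnuz layout, so block-quantized MoE weights must be rewritten and their scales adjusted before use; previously the block-quant path returned without doing this. For reference, a minimal sketch of the normalize_e4m3fn_to_e4m3fnuz helper, modeled on the version in vLLM's quantization utils (w8a8_utils.py in recent trees) -- treat the body as illustrative, not the exact source:

from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz(
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # Bit pattern 0b10000000 (-128 as int8) encodes -0.0 in e4m3fn but
    # NaN in e4m3fnuz, so remap it to +0.0 before reinterpreting storage.
    weight_as_int8 = weight.view(torch.int8)
    weight_as_int8[weight_as_int8 == -128] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
    # For every other bit pattern, the e4m3fnuz value is exactly half the
    # e4m3fn value (exponent bias 8 vs. 7), so doubling the scales keeps
    # the dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale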
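The scale doubling can be checked directly by reinterpreting a single byte under both formats. A quick sanity check, assuming a PyTorch build (roughly 2.2 or newer) that ships the float8 dtypes:

import torch

bits = torch.tensor([0x40], dtype=torch.uint8)
as_fn = bits.view(torch.float8_e4m3fn).to(torch.float32)      # tensor([2.])
as_fnuz = bits.view(torch.float8_e4m3fnuz).to(torch.float32)  # tensor([1.])
assert torch.equal(as_fn, as_fnuz * 2.0)

The same bits decode to half the value under e4m3fnuz because its exponent bias is 8 rather than 7; the doubled weight and input scales in the diff compensate for exactly this.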
