Conditional requantization and assert on padding in block quant
gshtras committed Jan 16, 2025
1 parent 6747657 commit 73aedac
Showing 1 changed file with 32 additions and 30 deletions.
62 changes: 32 additions & 30 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -166,6 +166,7 @@ def create_weights(
         weight_loader = extra_weight_attrs.get("weight_loader")

         if self.block_quant:
+            assert not envs.VLLM_FP8_PADDING, "FP8 weight padding is not supported in block quantization."

Check failure from GitHub Actions / ruff (3.12): Ruff (E501) at vllm/model_executor/layers/quantization/fp8.py:169:81, line too long (106 > 80).

             tp_size = get_tensor_model_parallel_world_size()
             assert self.quant_config.weight_block_size is not None
             block_n, block_k = (
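The new assertion makes the constraint explicit: FP8 weight padding and block quantization are mutually exclusive, so envs.VLLM_FP8_PADDING has to be turned off before a block-quantized FP8 checkpoint is loaded. A minimal usage sketch, assuming VLLM_FP8_PADDING is read from an environment variable of the same name and that "0" disables padding (the model path is illustrative):

import os

# Assumption: vllm reads VLLM_FP8_PADDING from the environment. Set it before
# importing vllm so Fp8LinearMethod.create_weights() sees padding disabled and
# the new assert is not triggered.
os.environ["VLLM_FP8_PADDING"] = "0"

from vllm import LLM  # noqa: E402

# Hypothetical path to a block-quantized FP8 checkpoint.
llm = LLM(model="/models/block-quant-fp8-model")
print(llm.generate("Hello")[0].outputs[0].text)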
@@ -253,14 +254,14 @@ def create_weights(
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
-
-            weight, weight_scale, input_scale = \
-                normalize_e4m3fn_to_e4m3fnuz(
-                    weight=layer.weight,
-                    weight_scale=layer.weight_scale_inv,
-                    input_scale=layer.input_scale)
-            layer.weight = Parameter(weight, requires_grad=False)
-            layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)
+            if current_platform.is_rocm() and not is_navi():
+                weight, weight_scale, _ = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.weight,
+                        weight_scale=layer.weight_scale_inv,
+                        input_scale=layer.input_scale)
+                layer.weight = Parameter(weight, requires_grad=False)
+                layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)

Check failure from GitHub Actions / ruff (3.12): Ruff (E501) at vllm/model_executor/layers/quantization/fp8.py:264:81, line too long (97 > 80).

             return
         layer.weight = torch.nn.Parameter(layer.weight.data,
                                           requires_grad=False)
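For context, a rough sketch of what a conversion like normalize_e4m3fn_to_e4m3fnuz amounts to, treating the details as assumptions rather than the exact vllm implementation. CDNA GPUs such as MI300 use the e4m3fnuz FP8 encoding, whose exponent bias differs from OCP e4m3fn by one, so a given bit pattern decodes to half the e4m3fn value; reinterpreting a serialized e4m3fn checkpoint therefore also requires doubling the scales. Navi parts skip this path and keep the checkpoint's e4m3fn weights as-is.

from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz_sketch(
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Illustrative only; the real helper lives in vllm and may differ."""
    assert weight.dtype == torch.float8_e4m3fn
    # The e4m3fn bit pattern 0x80 (negative zero) decodes as NaN in e4m3fnuz,
    # so zero those elements out before reinterpreting the storage.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0
    weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)
    # Each fnuz value is half of its e4m3fn counterpart, so the scales double.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale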
@@ -521,29 +522,30 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
-            w13_weight, w13_weight_scale_inv, w13_input_scale = \
-                normalize_e4m3fn_to_e4m3fnuz(
-                    layer.w13_weight, layer.w13_weight_scale_inv,
-                    layer.w13_input_scale)
-            w2_weight, w2_weight_scale_inv, w2_input_scale = \
-                normalize_e4m3fn_to_e4m3fnuz(
-                    layer.w2_weight, layer.w2_weight_scale_inv,
-                    layer.w2_input_scale)
-            # Reset the parameter
-            layer.w13_weight = torch.nn.Parameter(w13_weight,
-                                                  requires_grad=False)
-            layer.w13_weight_scale_inv = torch.nn.Parameter(
-                w13_weight_scale_inv, requires_grad=False)
-            if w13_input_scale is not None:
-                layer.w13_input_scale = torch.nn.Parameter(
-                    w13_input_scale, requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(w2_weight,
-                                                 requires_grad=False)
-            layer.w2_weight_scale_inv = torch.nn.Parameter(w2_weight_scale_inv,
-                                                           requires_grad=False)
-            if w2_input_scale is not None:
-                layer.w2_input_scale = torch.nn.Parameter(
-                    w2_input_scale, requires_grad=False)
+            if current_platform.is_rocm() and not is_navi():
+                w13_weight, w13_weight_scale_inv, w13_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale_inv,
+                        layer.w13_input_scale)
+                w2_weight, w2_weight_scale_inv, w2_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale_inv,
+                        layer.w2_input_scale)
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                      requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale_inv, requires_grad=False)
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                     requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(w2_weight_scale_inv,

Check failure from GitHub Actions / ruff (3.12): Ruff (E501) at vllm/model_executor/layers/quantization/fp8.py:544:81, line too long (83 > 80).

+                                                               requires_grad=False)
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False)
             return
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
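Both hunks gate the requantization on current_platform.is_rocm() and not is_navi(). A hedged sketch of the kind of probe such a helper typically performs, illustrative only since the real is_navi() in this fork may be implemented differently: Navi (RDNA) GPUs report gfx1xxx architecture names and keep the checkpoint's e4m3fn weights, while CDNA parts such as MI300 (gfx94x) take the e4m3fnuz path above.

import torch


def is_navi_sketch() -> bool:
    # Assumption: ROCm builds of PyTorch expose the GPU architecture via
    # get_device_properties(...).gcnArchName, and Navi/RDNA parts report names
    # starting with "gfx1". On non-ROCm builds this simply returns False.
    if not torch.cuda.is_available() or torch.version.hip is None:
        return False
    return torch.cuda.get_device_properties(0).gcnArchName.startswith("gfx1")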
