From 4bcfde6bd0954e1cf2014edd48dd0914902f6a40 Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Tue, 25 Jun 2024 06:43:33 +0000
Subject: [PATCH] Some renaming

---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 1 +
 .../layers/quantization/utils/marlin_utils.py     | 2 +-
 vllm/model_executor/models/mixtral_quant.py       | 6 +++---
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index a18b32694d257..5c5a662ec75d7 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -331,6 +331,7 @@ def get_default_config(
     }
     return config
 
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
index a2287072fe94d..43895e4696918 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -33,7 +33,7 @@ def get_scale_perms(num_bits):
             [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
     return scale_perm, scale_perm_single
 
-def marlin_permute_scales_2(s, size_k, size_n, group_size, num_bits):
+def marlin_permute_scales_numbits(s, size_k, size_n, group_size, num_bits):
     scale_perm, scale_perm_single = get_scale_perms(num_bits)
     if group_size < size_k and group_size != -1:
         s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py
index 5935cc433de48..e2c52911ebc82 100644
--- a/vllm/model_executor/models/mixtral_quant.py
+++ b/vllm/model_executor/models/mixtral_quant.py
@@ -53,7 +53,7 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput
 
-from vllm.model_executor.layers.quantization.utils.marlin_utils import marlin_permute_scales_2
+from vllm.model_executor.layers.quantization.utils.marlin_utils import marlin_permute_scales_numbits
 
 
 class MixtralMLP(nn.Module):
@@ -158,7 +158,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         g_idx_sort_idx = torch.empty(0, dtype=torch.int, device=w13_qw.device)
         w13_qw = ops.gptq_marlin_repack(w13_qw, g_idx_sort_idx, size_k, size_n,
                                         self.quant_config.weight_bits)
-        w13_s = marlin_permute_scales_2(w13_s, size_k, size_n,
+        w13_s = marlin_permute_scales_numbits(w13_s, size_k, size_n,
                                         self.quant_config.group_size,
                                         self.quant_config.weight_bits)
 
@@ -166,7 +166,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         size_n = w2_qw.shape[1]
         w2_qw = ops.gptq_marlin_repack(w2_qw, g_idx_sort_idx, size_k, size_n,
                                        self.quant_config.weight_bits)
-        w2_s = marlin_permute_scales_2(w2_s, size_k, size_n,
+        w2_s = marlin_permute_scales_numbits(w2_s, size_k, size_n,
                                        self.quant_config.group_size,
                                        self.quant_config.weight_bits)
 
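
Reviewer note (not part of the patch): the marlin_utils.py hunk truncates the
body of the renamed helper, so below is a minimal, self-contained sketch of
the scale permutation it performs. The names get_scale_perms_sketch and
permute_scales_sketch are hypothetical; the per-group permutation table is an
assumption (only the channel-wise table is visible in the diff), and the real
tables built by get_scale_perms() depend on the Marlin kernel's tile layout.

    import torch

    def get_scale_perms_sketch(num_bits):
        # Hypothetical stand-in for get_scale_perms() in marlin_utils.py.
        # num_bits is accepted for signature parity but unused here.
        # Per-group permutation (64 entries) -- assumed layout.
        scale_perm = []
        for i in range(8):
            scale_perm.extend([i + 8 * j for j in range(8)])
        # Channel-wise permutation (32 entries) -- matches the fragment
        # visible in the hunk above.
        scale_perm_single = []
        for i in range(4):
            scale_perm_single.extend(
                [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
        return scale_perm, scale_perm_single

    def permute_scales_sketch(s, size_k, size_n, group_size, num_bits):
        scale_perm, scale_perm_single = get_scale_perms_sketch(num_bits)
        if group_size < size_k and group_size != -1:
            # Grouped quantization: permute scales within rows of 64.
            s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
        else:
            # Channel-wise quantization: permute within rows of 32.
            s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
        return s.reshape((-1, size_n)).contiguous()

    # Example: 4-bit weights with K=256, N=64, group_size=64 give a
    # (256 // 64, 64) = (4, 64) scale tensor; the permutation reorders
    # entries without changing the shape.
    scales = torch.rand(256 // 64, 64, dtype=torch.float16)
    out = permute_scales_sketch(scales, size_k=256, size_n=64,
                                group_size=64, num_bits=4)
    assert out.shape == (4, 64)

The group_size != -1 branch covers grouped quantization (one scale row per
group_size rows of K); group_size == -1, or a single group spanning all of K,
falls through to the channel-wise path.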