diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt
index bfbceb24aef96..98a66b6701ea9 100644
--- a/tests/weight_loading/models.txt
+++ b/tests/weight_loading/models.txt
@@ -15,4 +15,6 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
 awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
-fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
\ No newline at end of file
+fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
+marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
+marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
\ No newline at end of file
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 4af954b74e8b5..e5b40a64abc41 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -22,7 +22,8 @@
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
-    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod"
+    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
+    "MarlinLinearMethod"
 ]
diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py
index cdc5129a93b15..8f1b5370b4538 100644
--- a/vllm/model_executor/layers/quantization/marlin.py
+++ b/vllm/model_executor/layers/quantization/marlin.py
@@ -9,7 +9,10 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           PackedvLLMParameter)
 
 logger = init_logger(__name__)
 
@@ -132,6 +135,7 @@ def create_weights(
         **extra_weight_attrs,
     ):
         del output_size  # Unused.
+        weight_loader = extra_weight_attrs["weight_loader"]
 
         if params_dtype != torch.float16:
             raise ValueError(
@@ -170,64 +174,64 @@ def create_weights(
                 "Each permutation group must reside on the same gpu")
 
         # Quantized 4Bit weights packed into Int32.
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.tile_size,
                 output_size_per_partition * self.quant_config.tile_size //
                 self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight,
-            {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-                "marlin_tile_size": self.quant_config.tile_size,
-            },
-        )
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            marlin_tile_size=self.quant_config.tile_size,
+            weight_loader=weight_loader)
 
         # Determine if channelwise or not
         input_groups = (1 if self.quant_config.group_size == -1 else
                         input_size_per_partition //
                         self.quant_config.group_size)
 
-        scales = Parameter(
+        weight_scale_args = {
+            "data":
             torch.empty(
                 input_groups,
                 output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            scales,
-            {
-                "input_dim": None if input_groups == 1 else 0,
-                "output_dim": 1,
-            },
-        )
+            "weight_loader":
+            weight_loader
+        }
+        if input_groups == 1:
+            scales = ChannelQuantScaleParameter(output_dim=1,
+                                                **weight_scale_args)
+        else:
+            scales = GroupQuantScaleParameter(output_dim=1,
+                                              input_dim=0,
+                                              **weight_scale_args)
 
         # Allocate workspace (Used for internal locking mechanism)
         max_workspace_size = (
             output_size_per_partition //
             self.quant_config.min_n_threads) * self.quant_config.max_parallel
-        workspace = Parameter(torch.zeros(max_workspace_size,
-                                          device="cuda",
-                                          dtype=torch.int),
-                              requires_grad=False)
+
+        workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
+                                                       device="cuda",
+                                                       dtype=torch.int),
+                                      weight_loader=weight_loader)
 
         layer.register_parameter("B", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("s", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
         layer.register_parameter("workspace", workspace)
-        set_weight_attrs(workspace, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # required by torch.compile
+        layer.B = Parameter(layer.B.data, requires_grad=False)
+        layer.s = Parameter(layer.s.data, requires_grad=False)
+        layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
 
     def apply(
         self,
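Note: the point of moving qweight to PackedvLLMParameter is that the packing metadata (packed_dim, packed_factor, marlin_tile_size) now travels with the parameter, so the v2 weight loader can remap shard sizes and offsets into the packed, tiled tensor instead of reading ad-hoc attributes set via set_weight_attrs. A minimal standalone sketch of that index arithmetic follows; the helper name and signature are illustrative, not the vLLM source.

from typing import Optional, Tuple


def adjust_shard_indexes_for_packing(
        shard_size: int, shard_offset: int, packed_factor: int,
        marlin_tile_size: Optional[int]) -> Tuple[int, int]:
    # Illustrative sketch, not the vLLM implementation.
    # Several quantized values share one int32 (packed_factor = 8 for
    # 4-bit weights), so unpacked sizes and offsets shrink by that factor.
    shard_size = shard_size // packed_factor
    shard_offset = shard_offset // packed_factor
    # Marlin stores the packed weights in tiles, which scales both the
    # shard size and the shard offset back up by the tile size.
    if marlin_tile_size is not None:
        shard_size *= marlin_tile_size
        shard_offset *= marlin_tile_size
    return shard_size, shard_offset


# Example: a 4096-wide shard at offset 8192 with packed_factor=8 and
# tile size 16 maps to width 4096 // 8 * 16 = 8192 at offset 16384.
print(adjust_shard_indexes_for_packing(4096, 8192, 8, 16))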