diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 96bfde9655961..c972e1e4c8ba2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -450,9 +450,13 @@ def supports_cutlass_24( :return: True if the layer is supported by the Cutlass 2:4 Kernel False otherwise """ - is_valid_sparsity_structure = (sparsity_scheme is not None - and sparsity_scheme.sparsity_structure - == SparsityStructure.TWO_FOUR.value) + if sparsity_scheme is None: + return False + + is_valid_sparsity_structure: bool = ( + sparsity_scheme.sparsity_structure == + SparsityStructure.TWO_FOUR.value) + valid_compressors = { CompressionFormat.dense.value, CompressionFormat.sparse_24_bitmask.value diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index e9141d2a8f7f4..ea20f970e3752 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from compressed_tensors import CompressionFormat, ModelCompressor @@ -268,9 +268,9 @@ def _process_split(bitmask_compressed_weight: torch.Tensor, shape, ) return sparsity_compressor.decompress_weight(weight_data) - split_weights = None - split_bitmask = None - split_shape = None + split_weights: List[torch.Tensor] = [] + split_bitmask: List[torch.Tensor] = [] + split_shape: List[Tuple[int, int]] = [] if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)): split_weights = torch.split(compressed, layer.logical_widths) @@ -278,7 +278,7 @@ def _process_split(bitmask_compressed_weight: torch.Tensor, shape, split_shape = [(out, layer.input_size_per_partition) for out in layer.logical_widths] - if split_weights is not None: + if split_weights: decompressed_shards = [ _process_split(compressed_weight, shape, bitmask) for compressed_weight, shape, bitmask in zip(