diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py index d410e64cf..e2346a2fe 100644 --- a/src/compressed_tensors/compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressor.py @@ -311,16 +311,16 @@ def update_config(self, save_directory: str): with open(config_file_path, "r") as config_file: config_data = json.load(config_file) - config_data[COMPRESSION_CONFIG_NAME] = {} + config_data[QUANTIZATION_CONFIG_NAME] = {} if self.quantization_config is not None: quant_config_data = self.quantization_config.model_dump() - config_data[COMPRESSION_CONFIG_NAME] = quant_config_data + config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data if self.sparsity_config is not None: sparsity_config_data = self.sparsity_config.model_dump() - config_data[COMPRESSION_CONFIG_NAME][ + config_data[QUANTIZATION_CONFIG_NAME][ SPARSITY_CONFIG_NAME ] = sparsity_config_data - config_data[COMPRESSION_CONFIG_NAME][ + config_data[QUANTIZATION_CONFIG_NAME][ COMPRESSION_VERSION_NAME ] = compressed_tensors.__version__