Spiking Diffusion Models
(IEEE Transactions on Artificial Intelligence)

Jiahang Cao¹*, Hanzhong Guo²*, Ziqing Wang³*, Deming Zhou¹, Hao Cheng¹, Qiang Zhang¹, Renjing Xu¹†

¹The Hong Kong University of Science and Technology (Guangzhou)
²The Hong Kong University
³Northwestern University


This work, SDM, is an extended version of SDDPM. We introduce several key improvements:

  • A New Family of Spiking-based Diffusion Models: This work extends spiking diffusion models to a wider range of diffusion solvers, including but not limited to DDPM, DDIM, Analytic-DPM, and DPM-Solver.
  • Biologically Inspired Temporal-wise Spiking Mechanism (TSM): Inspired by the biological observation that a neuron's input fluctuates considerably from moment to moment, rather than being governed predominantly by fixed synaptic weights, this module enables spiking neurons to capture more dynamic information. TSM can be combined with the existing modules proposed in SDDPM to further improve image generation quality.
  • ANN-SNN Conversion for SDM: To the best of our knowledge, we make the first attempt to implement spiking diffusion models via ANN-SNN conversion, and we provide theoretical foundations for this approach.
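
As background for the ANN-SNN conversion bullet, the sketch below illustrates the generic rate-coding equivalence that conversion methods commonly build on. It is not the SDM conversion procedure (see the paper for that). It assumes an integrate-and-fire (IF) neuron with soft reset and an initial membrane potential of half the threshold, driven by a constant input `x` for `T` steps. Under these assumptions, the spike count times `threshold / T` equals a ReLU clipped to `[0, threshold]` and quantized to `T` levels, which is why an ANN trained with such an activation can be mapped onto spiking neurons. The function names are illustrative only.

```python
import math


def if_neuron_spike_count(x: float, threshold: float, T: int) -> int:
    """Simulate an integrate-and-fire neuron with soft reset for T steps.

    The neuron receives the same analog input x at every step and starts
    at half the threshold (an illustrative assumption, common in conversion work).
    """
    v = threshold / 2.0  # initial membrane potential
    spikes = 0
    for _ in range(T):
        v += x  # integrate the input current
        if v >= threshold:
            spikes += 1
            v -= threshold  # soft reset: subtract the threshold instead of zeroing
    return spikes


def quantized_clip_relu(x: float, threshold: float, T: int) -> float:
    """ANN-side activation the IF neuron reproduces: clip to [0, threshold],
    quantized to T levels with round-half-up."""
    level = math.floor(x * T / threshold + 0.5)
    level = min(max(level, 0), T)
    return level * threshold / T


if __name__ == "__main__":
    theta, T = 1.0, 8  # T = 8 simulation steps, chosen for illustration
    for x in (-0.3, 0.0, 0.2, 0.55, 0.9, 1.7):
        snn_out = if_neuron_spike_count(x, theta, T) * theta / T
        ann_out = quantized_clip_relu(x, theta, T)
        print(f"x={x:+.2f}  SNN rate*theta={snn_out:.3f}  ANN={ann_out:.3f}")
```

Larger `T` gives finer quantization levels, so the spiking output tracks the analog activation more closely at the cost of more simulation steps.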


Requirements

Please see SDDPM.

TSM Finetune

Below is example code for fine-tuning SDM models that inherit the weights obtained from SDDPM pre-training:

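import torch
from collections import OrderedDict

from TSM import Spk_UNet_TSM
# Spk_UNet (the standard spiking UNet) comes from the SDDPM codebase;
# `args` holds the same training configuration used for SDDPM.

# ... (First, pretrain the standard SNN UNet with SDDPM)

pretrained_model = Spk_UNet(
    T=args.T, ch=args.ch, ch_mult=args.ch_mult, attn=args.attn,
    num_res_blocks=args.num_res_blocks, dropout=args.dropout, timestep=args.timestep, img_ch=args.img_ch)

# Load the pretrained SDDPM checkpoint (replace the placeholder path)
ckpt = torch.load('/your/pretrained_checkpoint', map_location='cpu')
pretrained_model.load_state_dict(ckpt['net_model'], strict=True)
pretrained_dict = pretrained_model.state_dict()

# Build the TSM model with the same hyperparameters
net_model = Spk_UNet_TSM(
    T=args.T, ch=args.ch, ch_mult=args.ch_mult, attn=args.attn,
    num_res_blocks=args.num_res_blocks, dropout=args.dropout, timestep=args.timestep, img_ch=args.img_ch)

model_dict = net_model.state_dict()
new_state_dict = OrderedDict()

for name, para in pretrained_dict.items():
    if name in model_dict:
        # Parameter exists unchanged in the TSM model: copy it directly
        new_state_dict[name] = para
    elif 'conv' in name and name.endswith('.weight'):
        # Conv weight now lives inside the TSM wrapper: '<prefix>.weight' -> '<prefix>.tsmconv.weight'
        head = name[:-len('.weight')]
        new_state_dict[head + '.tsmconv.weight'] = para
    elif 'conv' in name and name.endswith('.bias'):
        # Conv bias: '<prefix>.bias' -> '<prefix>.tsmconv.bias'
        head = name[:-len('.bias')]
        new_state_dict[head + '.tsmconv.bias'] = para

# strict=False: TSM-specific parameters have no pretrained counterpart and keep their initialization
result = net_model.load_state_dict(new_state_dict, strict=False)
print(f'Missing keys: {len(result.missing_keys)}, unexpected keys: {len(result.unexpected_keys)}')
print('-------Successfully inherited pretrained weights-------')

# ... (Next, fine-tune the TSM SDM with the same training code from SDDPM)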
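
The name-remapping rule in the loop above can be isolated and checked without loading any checkpoint. The sketch below works on parameter names only and assumes, as the loop does, that convolution parameters are recognizable by `conv` in their name and move under a `.tsmconv` submodule in the TSM model. The example key names, including the TSM-only parameter `head.conv.alpha`, are hypothetical and do not come from the actual `Spk_UNet` or `Spk_UNet_TSM` classes. As a design choice, it also drops remapped names the TSM model does not define, so the returned mapping lists only parameters that will actually be loaded.

```python
from collections import OrderedDict


def remap_to_tsm(pretrained_keys, tsm_keys):
    """Return an OrderedDict {tsm_name: pretrained_name}.

    Names already present in tsm_keys are kept as-is; conv weights/biases
    missing from tsm_keys are redirected to '<prefix>.tsmconv.<weight|bias>'.
    """
    tsm_keys = set(tsm_keys)
    mapping = OrderedDict()
    for name in pretrained_keys:
        if name in tsm_keys:
            mapping[name] = name
        elif 'conv' in name and name.endswith(('.weight', '.bias')):
            prefix, _, suffix = name.rpartition('.')
            new_name = f'{prefix}.tsmconv.{suffix}'
            if new_name in tsm_keys:  # keep only targets the TSM model actually has
                mapping[new_name] = name
    return mapping


if __name__ == "__main__":
    # Hypothetical parameter names, for illustration only
    pretrained = ['head.conv.weight', 'head.conv.bias', 'blocks.0.norm.weight', 'tail.conv.weight']
    tsm = {'head.conv.tsmconv.weight', 'head.conv.tsmconv.bias',
           'blocks.0.norm.weight', 'tail.conv.weight', 'head.conv.alpha'}
    mapping = remap_to_tsm(pretrained, tsm)
    for new, old in mapping.items():
        print(f'{old} -> {new}')
    print('left at initialization:', sorted(tsm - set(mapping)))
```

With a real state dict, the loadable weights would then be `{new: pretrained_dict[old] for new, old in mapping.items()}`, and the TSM parameters absent from the mapping are exactly those that start from their fresh initialization.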

Sample

Example code for sampling images with the DDIM solver is given below.

We release the SDM checkpoint trained on CIFAR-10 with snn_timesteps=8. You can download it through this link.

cd SDM
CUDA_VISIBLE_DEVICES=0 python sample.py
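
For orientation, here is a minimal, framework-free sketch of the deterministic DDIM update (η = 0) that a DDIM sampler applies at each step, written on scalar values. It assumes the linear β schedule from the original DDPM paper (1e-4 to 0.02 over 1000 steps) and a stride of 20 (50 sampling steps). The function names, the `eps_model(x, t)` signature, and the stride are assumptions for illustration and are not the interface of `sample.py`, where the noise predictor would be the spiking UNet operating on image tensors.

```python
import math
from typing import Callable, List


def linear_alpha_bars(num_steps: int = 1000, beta_start: float = 1e-4,
                      beta_end: float = 0.02) -> List[float]:
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule."""
    alpha_bars, prod = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def ddim_step(x_t: float, eps: float, ab_t: float, ab_prev: float) -> float:
    """Deterministic DDIM update (eta = 0) from timestep t to the previous sampled timestep."""
    # Predict the clean sample from the noisy sample and the predicted noise
    x0_pred = (x_t - math.sqrt(1.0 - ab_t) * eps) / math.sqrt(ab_t)
    # Move the prediction to the previous noise level, reusing the same eps
    return math.sqrt(ab_prev) * x0_pred + math.sqrt(1.0 - ab_prev) * eps


def ddim_sample(x_T: float, eps_model: Callable[[float, int], float],
                alpha_bars: List[float], stride: int = 20) -> float:
    """Run DDIM over the strided timestep subsequence T-1, T-1-stride, ..."""
    timesteps = list(range(len(alpha_bars) - 1, -1, -stride))
    x = x_T
    for i, t in enumerate(timesteps):
        # After the last sampled timestep, step to alpha_bar = 1 (the clean sample)
        ab_prev = alpha_bars[timesteps[i + 1]] if i + 1 < len(timesteps) else 1.0
        x = ddim_step(x, eps_model(x, t), alpha_bars[t], ab_prev)
    return x


if __name__ == "__main__":
    alpha_bars = linear_alpha_bars()
    target = 0.7  # a single "data point"
    # Oracle noise predictor, exact when the data distribution is the single point `target`
    oracle = lambda x, t: (x - math.sqrt(alpha_bars[t]) * target) / math.sqrt(1.0 - alpha_bars[t])
    print(ddim_sample(1.3, oracle, alpha_bars))  # ≈ 0.7 up to floating-point error
```

With the oracle predictor every step recovers the same clean estimate, so the sampler lands on `target`; this is only a sanity check of the update rule, not a model of real image sampling.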

Citation

If you find our work useful, please consider citing:

@article{cao2024spiking,
  title={Spiking Diffusion Models},
  author={Cao, Jiahang and Guo, Hanzhong and Wang, Ziqing and Zhou, Deming and Cheng, Hao and Zhang, Qiang and Xu, Renjing},
  journal={arXiv preprint arXiv:2408.16467},
  year={2024}
}

Acknowledgements & Contact

We thank the authors of pytorch-ddpm, Fast-SNN, and spikingjelly for open-sourcing their code.

For help or issues with this project, please contact [email protected].