Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EDM Diffusion Models #193

Merged
merged 32 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4355cf4
adding u-net architectures
mnabian Oct 3, 2023
894eae7
add model metadata
mnabian Oct 3, 2023
3d8dcb0
add loss functions
mnabian Oct 3, 2023
858d63f
add unet
mnabian Oct 4, 2023
7fde867
pytests for song unet
mnabian Oct 6, 2023
3b48dcd
add unit tests for Dhariwal UNet
mnabian Oct 9, 2023
089aced
loss tests
mnabian Oct 9, 2023
a709c16
revert resdiff changes to edm
mnabian Oct 12, 2023
0dcffc9
fix typo
mnabian Oct 17, 2023
26de9fa
add generative recipes
mnabian Oct 17, 2023
7ac1dcb
Merge branch 'main' into fea-ext-diffusion
mnabian Oct 17, 2023
895f3f4
update changelog
mnabian Oct 17, 2023
fe046ae
update changelog
mnabian Oct 17, 2023
6fe6fb7
revert license update
mnabian Oct 17, 2023
812d1ec
clean up fid
mnabian Oct 17, 2023
272dfdf
ruff fixes
mnabian Oct 18, 2023
55988a8
change example path
mnabian Oct 18, 2023
22e2e25
fix failing pytests
mnabian Oct 18, 2023
bd584f3
ruff exception for examples
mnabian Oct 18, 2023
165a99f
Merge branch 'main' into fea-ext-diffusion
mnabian Oct 18, 2023
29c3304
formatting
mnabian Oct 18, 2023
416d1ba
fix circular import
mnabian Oct 18, 2023
559f5f0
fix circular import
mnabian Oct 18, 2023
7bf8e2b
deduplicate unit test name
mnabian Oct 18, 2023
56ea178
updating regression test data
mnabian Oct 18, 2023
3745435
remove unet doctest
mnabian Oct 18, 2023
2bac3c9
fix ddp
mnabian Oct 19, 2023
7ce7b72
add EDM preconditioning for super-resolution tasks
mnabian Oct 20, 2023
eb6cdc5
address review comments-part 1
mnabian Oct 25, 2023
47ed358
address review comments- part 2
mnabian Oct 26, 2023
b07243e
Merge branch 'main' into fea-ext-diffusion
mnabian Oct 26, 2023
fba9233
remove reference to lustre
mnabian Oct 27, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added distributed FFT utility.
- Added ruff as a linting tool.
- Ported utilities from Modulus Launch to main package.
- EDM diffusion models and recipes for training and sampling.

### Changed

Expand Down
96 changes: 96 additions & 0 deletions examples/generative/diffusion/conf/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: True
run:
dir: ./outputs/


# Main options
outdir: ./results # Where to save the results
data: ./data # Path to the dataset
cond: true # Train class-conditional model
arch: ddpmpp # Network architecture
precond: edm # Preconditioning & loss function
dataset: 'cifar10'

# Hyperparameters
duration: 200 # Training duration
batch: 128 # Total batch size
batch_gpu: null # Limit batch size per GPU
cbase: null # Channel multiplier
cres: null # Channels per resolution
lr: 10e-4 # Learning rate
ema: 0.5 # EMA half-life
dropout: 0.13 # Dropout probability
augment: null # Augment probability
xflip: false # Enable dataset x-flips

# Performance-related
fp16: false # Enable mixed-precision training
ls: 1.0 # Loss scaling
bench: true # enable cuDNN benchmarking
cache: true # Cache dataset in CPU memory
workers: 1 # DataLoader worker processes
fused_adam: false # Whether to use fused Adam optimizer

# I/O-related
desc: null # String to include in result dir name
nosubdir: false # If True, do not create a subdirectory for results
tick: 50 # How often to print progress
snap: 50 # How often to save snapshots
dump: 500 # How often to dump state
seed: null # Random seed
transfer: null # Transfer learning from network pickle
resume: null # Resume from previous training state
dry_run: false # Print training options and exit

# Generation-related
ckpt_filename: checkpoint # Checkpoint filename to be used for generation
img_outdir: results_images # Where to save the output images
gen_seeds: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58,
59, 60, 61, 62] # Random seeds used for generation
subdirs: true # Create subdirectory for every 1000 seeds
class_idx: null # Class label. Null is random
max_batch_size: 64 # maximum batch size
num_steps: 18 # Number of sampling steps
sigma_min: null # Lowest noise level
sigma_max: null # Highest noise level
rho: 7 # Time step exponent
s_churn: 0. # Stochasticity strength
s_min: 0. # Stochasticity min noise level
s_max: .inf # Stochasticity max noise level
s_noise: 1. # Stochasticity noise inflation
solver: heun # ODE solver [euler, heun]
discretization: edm # Time step discretization [vp, ve, iddpm, edm]
schedule: linear # noise schedule sigma(t) [vp, ve, linear]
scaling: null # Signal scaling s(t) [vp, none]



# # Weather-related
# data_config: ? # String to include the data config
# task: ? # String to include the task
# data_type: ? # String to include the data type

# # Regression
# ckpt_unet: ? # Checkpoint for the UNet to predict the mean
dallasfoster marked this conversation as resolved.
Show resolved Hide resolved




42 changes: 42 additions & 0 deletions examples/generative/diffusion/conf/config_fid.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: True
run:
dir: ./outputs/

# Main options
mode: calc # calc: calculate FID for a given set of images
# ref: Calculate dataset reference statistics needed by 'calc'

# FID options
image_path: ./images # Path to the images
ref_path: ./ref # Dataset reference statistics
num_expected: 50000 # Number of images to use
seed: 0 # Random seed for selecting the images
batch: 64 # Maximum batch size

# Reference statistics options
dataset_path: ./data # Path to the dataset
dest_path: ./dest.npz # Destination .npz file
batch: 64 # Maximum batch size







17 changes: 17 additions & 0 deletions examples/generative/diffusion/dataset/__init__.py
dallasfoster marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.anguage governing permissions and
# limitations under the License.


from .dataset import ImageFolderDataset
Loading