-
Notifications
You must be signed in to change notification settings - Fork 275
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* adding u-net architectures * add model metadata * add loss functions * add unet * pytests for song unet * add unit tests for Dhariwal UNet * loss tests * revert resdiff changes to edm * fix typo * add generative recipes * update changelog * revert license update * clean up fid * ruff fixes * change example path * fix failing pytests * ruff exception for examples * formatting * fix circular import * fix circular import * deduplicate unit test name * updating regression test data * remove unet doctest * fix ddp * add EDM preconditioning for super-resolution tasks * address review comments-part 1 * address review comments- part 2 * remove reference to lustre
- Loading branch information
Showing
30 changed files
with
6,866 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
hydra: | ||
job: | ||
chdir: True | ||
run: | ||
dir: ./outputs/ | ||
|
||
|
||
# Main options | ||
outdir: ./results # Where to save the results | ||
data: ./data # Path to the dataset | ||
cond: true # Train class-conditional model | ||
arch: ddpmpp # Network architecture | ||
precond: edm # Preconditioning & loss function | ||
dataset: 'cifar10' | ||
|
||
# Hyperparameters | ||
duration: 200 # Training duration | ||
batch: 128 # Total batch size | ||
batch_gpu: null # Limit batch size per GPU | ||
cbase: null # Channel multiplier | ||
cres: null # Channels per resolution | ||
lr: 10e-4 # Learning rate | ||
ema: 0.5 # EMA half-life | ||
dropout: 0.13 # Dropout probability | ||
augment: null # Augment probability | ||
xflip: false # Enable dataset x-flips | ||
|
||
# Performance-related | ||
fp16: false # Enable mixed-precision training | ||
ls: 1.0 # Loss scaling | ||
bench: true # enable cuDNN benchmarking | ||
cache: true # Cache dataset in CPU memory | ||
workers: 1 # DataLoader worker processes | ||
fused_adam: false # Whether to use fused Adam optimizer | ||
|
||
# I/O-related | ||
desc: null # String to include in result dir name | ||
nosubdir: false # If True, do not create a subdirectory for results | ||
tick: 50 # How often to print progress | ||
snap: 50 # How often to save snapshots | ||
dump: 500 # How often to dump state | ||
seed: null # Random seed | ||
transfer: null # Transfer learning from network pickle | ||
resume: null # Resume from previous training state | ||
dry_run: false # Print training options and exit | ||
|
||
# Generation-related | ||
ckpt_filename: checkpoint # Checkpoint filename to be used for generation | ||
img_outdir: results_images # Where to save the output images | ||
gen_seeds: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, | ||
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, | ||
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, | ||
59, 60, 61, 62] # Random seeds used for generation | ||
subdirs: true # Create subdirectory for every 1000 seeds | ||
class_idx: null # Class label. Null is random | ||
max_batch_size: 64 # maximum batch size | ||
num_steps: 18 # Number of sampling steps | ||
sigma_min: null # Lowest noise level | ||
sigma_max: null # Highest noise level | ||
rho: 7 # Time step exponent | ||
s_churn: 0. # Stochasticity strength | ||
s_min: 0. # Stochasticity min noise level | ||
s_max: .inf # Stochasticity max noise level | ||
s_noise: 1. # Stochasticity noise inflation | ||
solver: heun # ODE solver [euler, heun] | ||
discretization: edm # Time step discretization [vp, ve, iddpm, edm] | ||
schedule: linear # noise schedule sigma(t) [vp, ve, linear] | ||
scaling: null # Signal scaling s(t) [vp, none] | ||
|
||
|
||
|
||
# # Weather-related | ||
# data_config: ? # String to include the data config | ||
# task: ? # String to include the task | ||
# data_type: ? # String to include the data type | ||
|
||
# # Regression | ||
# ckpt_unet: ? # Checkpoint for the UNet to predict the mean | ||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
hydra: | ||
job: | ||
chdir: True | ||
run: | ||
dir: ./outputs/ | ||
|
||
# Main options | ||
mode: calc # calc: calculate FID for a given set of images | ||
# ref: Calculate dataset reference statistics needed by 'calc' | ||
|
||
# FID options | ||
image_path: ./images # Path to the images | ||
ref_path: ./ref # Dataset reference statistics | ||
num_expected: 50000 # Number of images to use | ||
seed: 0 # Random seed for selecting the images | ||
batch: 64 # Maximum batch size | ||
|
||
# Reference statistics options | ||
dataset_path: ./data # Path to the dataset | ||
dest_path: ./dest.npz # Destination .npz file | ||
batch: 64 # Maximum batch size | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License.anguage governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from .dataset import ImageFolderDataset |
Oops, something went wrong.