diff --git a/CHANGELOG.md b/CHANGELOG.md index f41550d7b8..cd447460d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added distributed FFT utility. - Added ruff as a linting tool. - Ported utilities from Modulus Launch to main package. +- EDM diffusion models and recipes for training and sampling. ### Changed diff --git a/examples/generative/diffusion/conf/config.yaml b/examples/generative/diffusion/conf/config.yaml new file mode 100644 index 0000000000..2a038f7a8a --- /dev/null +++ b/examples/generative/diffusion/conf/config.yaml @@ -0,0 +1,96 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: True + run: + dir: ./outputs/ + + +# Main options +outdir: ./results # Where to save the results +data: ./data # Path to the dataset +cond: true # Train class-conditional model +arch: ddpmpp # Network architecture +precond: edm # Preconditioning & loss function +dataset: 'cifar10' + +# Hyperparameters +duration: 200 # Training duration +batch: 128 # Total batch size +batch_gpu: null # Limit batch size per GPU +cbase: null # Channel multiplier +cres: null # Channels per resolution +lr: 10e-4 # Learning rate +ema: 0.5 # EMA half-life +dropout: 0.13 # Dropout probability +augment: null # Augment probability +xflip: false # Enable dataset x-flips + +# Performance-related +fp16: false # Enable mixed-precision training +ls: 1.0 # Loss scaling +bench: true # enable cuDNN benchmarking +cache: true # Cache dataset in CPU memory +workers: 1 # DataLoader worker processes +fused_adam: false # Whether to use fused Adam optimizer + +# I/O-related +desc: null # String to include in result dir name +nosubdir: false # If True, do not create a subdirectory for results +tick: 50 # How often to print progress +snap: 50 # How often to save snapshots +dump: 500 # How often to dump state +seed: null # Random seed +transfer: null # Transfer learning from network pickle +resume: null # Resume from previous training state +dry_run: false # Print training options and exit + +# Generation-related +ckpt_filename: checkpoint # Checkpoint filename to be used for generation +img_outdir: results_images # Where to save the output images +gen_seeds: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, + 59, 60, 61, 62] # Random seeds used for generation +subdirs: true # Create subdirectory for every 1000 seeds +class_idx: null # Class label. Null is random +max_batch_size: 64 # maximum batch size +num_steps: 18 # Number of sampling steps +sigma_min: null # Lowest noise level +sigma_max: null # Highest noise level +rho: 7 # Time step exponent +s_churn: 0. # Stochasticity strength +s_min: 0. 
# Stochasticity min noise level
+s_max: .inf              # Stochasticity max noise level
+s_noise: 1.              # Stochasticity noise inflation
+solver: heun             # ODE solver [euler, heun]
+discretization: edm      # Time step discretization [vp, ve, iddpm, edm]
+schedule: linear         # noise schedule sigma(t) [vp, ve, linear]
+scaling: null            # Signal scaling s(t) [vp, none]
+
+
+
+# # Weather-related
+# data_config: ?         # String to include the data config
+# task: ?                # String to include the task
+# data_type: ?           # String to include the data type
+
+# # Regression
+# ckpt_unet: ?           # Checkpoint for the UNet to predict the mean
+
+
+
diff --git a/examples/generative/diffusion/conf/config_fid.yaml b/examples/generative/diffusion/conf/config_fid.yaml
new file mode 100644
index 0000000000..96a03800c1
--- /dev/null
+++ b/examples/generative/diffusion/conf/config_fid.yaml
@@ -0,0 +1,42 @@
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+hydra:
+  job:
+    chdir: True
+  run:
+    dir: ./outputs/
+
+# Main options
+mode: calc               # calc: calculate FID for a given set of images
+                         # ref: calculate dataset reference statistics needed by 'calc'
+
+# FID options
+image_path: ./images     # Path to the images
+ref_path: ./ref          # Dataset reference statistics
+num_expected: 50000      # Number of images to use
+seed: 0                  # Random seed for selecting the images
+batch: 64                # Maximum batch size
+
+# Reference statistics options
+dataset_path: ./data     # Path to the dataset
+dest_path: ./dest.npz    # Destination .npz file
+batch: 64                # Maximum batch size
+
+
+
+
+
+
+
diff --git a/examples/generative/diffusion/dataset/__init__.py b/examples/generative/diffusion/dataset/__init__.py
new file mode 100644
index 0000000000..f987df14ee
--- /dev/null
+++ b/examples/generative/diffusion/dataset/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .dataset import ImageFolderDataset
diff --git a/examples/generative/diffusion/dataset/dataset.py b/examples/generative/diffusion/dataset/dataset.py
new file mode 100644
index 0000000000..6589525f9e
--- /dev/null
+++ b/examples/generative/diffusion/dataset/dataset.py
@@ -0,0 +1,276 @@
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Streaming images and labels from datasets created with dataset_tool.py.""" + +import json +import os +import zipfile + +import numpy as np +import PIL.Image +import torch +from utils import EasyDict + +try: + import pyspng +except ImportError: + pyspng = None + + +class Dataset(torch.utils.data.Dataset): + """ + Abstract base class for datasets + """ + + def __init__( + self, + name, # Name of the dataset. + raw_shape, # Shape of the raw image data (NCHW). + max_size=None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. + use_labels=False, # Enable conditioning labels? False = label dimension is zero. + xflip=False, # Artificially double the size of the dataset via x-flips. Applied after max_size. + random_seed=0, # Random seed to use when applying max_size. + cache=False, # Cache images in CPU memory? + ): + self._name = name + self._raw_shape = list(raw_shape) + self._use_labels = use_labels + self._cache = cache + self._cached_images = dict() # {raw_idx: np.ndarray, ...} + self._raw_labels = None + self._label_shape = None + + # Apply max_size. + self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) + if (max_size is not None) and (self._raw_idx.size > max_size): + np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx) + self._raw_idx = np.sort(self._raw_idx[:max_size]) + + # Apply xflip. 
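+        # When enabled, the raw index list is duplicated and the second copy is
+        # flagged in self._xflip, so __getitem__ can mirror those samples on the
+        # fly (image[:, :, ::-1]) instead of storing flipped copies.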
+ self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) + if xflip: + self._raw_idx = np.tile(self._raw_idx, 2) + self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) + + def _get_raw_labels(self): + if self._raw_labels is None: + self._raw_labels = self._load_raw_labels() if self._use_labels else None + if self._raw_labels is None: + self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) + assert isinstance(self._raw_labels, np.ndarray) + assert self._raw_labels.shape[0] == self._raw_shape[0] + assert self._raw_labels.dtype in [np.float32, np.int64] + if self._raw_labels.dtype == np.int64: + assert self._raw_labels.ndim == 1 + assert np.all(self._raw_labels >= 0) + return self._raw_labels + + def close(self): # to be overridden by subclass + pass + + def _load_raw_image(self, raw_idx): # to be overridden by subclass + raise NotImplementedError + + def _load_raw_labels(self): # to be overridden by subclass + raise NotImplementedError + + def __getstate__(self): + return dict(self.__dict__, _raw_labels=None) + + def __del__(self): + try: + self.close() + except: + pass + + def __len__(self): + return self._raw_idx.size + + def __getitem__(self, idx): + raw_idx = self._raw_idx[idx] + image = self._cached_images.get(raw_idx, None) + if image is None: + image = self._load_raw_image(raw_idx) + if self._cache: + self._cached_images[raw_idx] = image + assert isinstance(image, np.ndarray) + assert list(image.shape) == self.image_shape + assert image.dtype == np.uint8 + if self._xflip[idx]: + assert image.ndim == 3 # CHW + image = image[:, :, ::-1] + return image.copy(), self.get_label(idx) + + def get_label(self, idx): + label = self._get_raw_labels()[self._raw_idx[idx]] + if label.dtype == np.int64: + onehot = np.zeros(self.label_shape, dtype=np.float32) + onehot[label] = 1 + label = onehot + return label.copy() + + def get_details(self, idx): + d = EasyDict() + d.raw_idx = int(self._raw_idx[idx]) + d.xflip = int(self._xflip[idx]) != 0 + d.raw_label = self._get_raw_labels()[d.raw_idx].copy() + return d + + @property + def name(self): + return self._name + + @property + def image_shape(self): + return list(self._raw_shape[1:]) + + @property + def num_channels(self): + assert len(self.image_shape) == 3 # CHW + return self.image_shape[0] + + @property + def resolution(self): + assert len(self.image_shape) == 3 # CHW + assert self.image_shape[1] == self.image_shape[2] + return self.image_shape[1] + + @property + def label_shape(self): + if self._label_shape is None: + raw_labels = self._get_raw_labels() + if raw_labels.dtype == np.int64: + self._label_shape = [int(np.max(raw_labels)) + 1] + else: + self._label_shape = raw_labels.shape[1:] + return list(self._label_shape) + + @property + def label_dim(self): + assert len(self.label_shape) == 1 + return self.label_shape[0] + + @property + def has_labels(self): + return any(x != 0 for x in self.label_shape) + + @property + def has_onehot_labels(self): + return self._get_raw_labels().dtype == np.int64 + + +class ImageFolderDataset(Dataset): + """ + Dataset subclass that loads images recursively from the specified directory + or ZIP file. + """ + + def __init__( + self, + path, # Path to directory or zip. + resolution=None, # Ensure specific resolution, None = highest available. + use_pyspng=True, # Use pyspng if available? + **super_kwargs, # Additional arguments for the Dataset base class. 
+ ): + self._path = path + self._use_pyspng = use_pyspng + self._zipfile = None + + if os.path.isdir(self._path): + self._type = "dir" + self._all_fnames = { + os.path.relpath(os.path.join(root, fname), start=self._path) + for root, _dirs, files in os.walk(self._path) + for fname in files + } + elif self._file_ext(self._path) == ".zip": + self._type = "zip" + self._all_fnames = set(self._get_zipfile().namelist()) + else: + raise IOError("Path must point to a directory or zip") + + PIL.Image.init() + self._image_fnames = sorted( + fname + for fname in self._all_fnames + if self._file_ext(fname) in PIL.Image.EXTENSION + ) + if len(self._image_fnames) == 0: + raise IOError("No image files found in the specified path") + + name = os.path.splitext(os.path.basename(self._path))[0] + raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) + if resolution is not None and ( + raw_shape[2] != resolution or raw_shape[3] != resolution + ): + raise IOError("Image files do not match the specified resolution") + super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) + + @staticmethod + def _file_ext(fname): + return os.path.splitext(fname)[1].lower() + + def _get_zipfile(self): + assert self._type == "zip" + if self._zipfile is None: + self._zipfile = zipfile.ZipFile(self._path) + return self._zipfile + + def _open_file(self, fname): + if self._type == "dir": + return open(os.path.join(self._path, fname), "rb") + if self._type == "zip": + return self._get_zipfile().open(fname, "r") + return None + + def close(self): + try: + if self._zipfile is not None: + self._zipfile.close() + finally: + self._zipfile = None + + def __getstate__(self): + return dict(super().__getstate__(), _zipfile=None) + + def _load_raw_image(self, raw_idx): + fname = self._image_fnames[raw_idx] + with self._open_file(fname) as f: + if ( + self._use_pyspng + and pyspng is not None + and self._file_ext(fname) == ".png" + ): + image = pyspng.load(f.read()) + else: + image = np.array(PIL.Image.open(f)) + if image.ndim == 2: + image = image[:, :, np.newaxis] # HW => HWC + image = image.transpose(2, 0, 1) # HWC => CHW + return image + + def _load_raw_labels(self): + fname = "dataset.json" + if fname not in self._all_fnames: + return None + with self._open_file(fname) as f: + labels = json.load(f)["labels"] + if labels is None: + return None + labels = dict(labels) + labels = [labels[fname.replace("\\", "/")] for fname in self._image_fnames] + labels = np.array(labels) + labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) + return labels diff --git a/examples/generative/diffusion/dataset/dataset_tool.py b/examples/generative/diffusion/dataset/dataset_tool.py new file mode 100644 index 0000000000..fb00177d46 --- /dev/null +++ b/examples/generative/diffusion/dataset/dataset_tool.py @@ -0,0 +1,586 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
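+
+# Illustrative usage (paths are placeholders): package CIFAR-10 into the ZIP
+# layout consumed by train.py and ImageFolderDataset:
+#
+#   python dataset_tool.py --source=downloads/cifar-10-python.tar.gz \
+#       --dest=datasets/cifar10-32x32.zip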
+ + +"""Tool for creating ZIP/PNG based datasets.""" + +import argparse +import functools +import gzip +import io +import json +import os +import pickle +import re +import sys +import tarfile +import zipfile +from pathlib import Path +from typing import Callable, Optional, Tuple, Union + +import numpy as np +import PIL.Image +from tqdm import tqdm + +# ---------------------------------------------------------------------------- +# Parse a 'M,N' or 'MxN' integer tuple. +# Example: '4x2' returns (4,2) + + +def parse_tuple(s: str) -> Tuple[int, int]: + m = re.match(r"^(\d+)[x,](\d+)$", s) + if m: + return int(m.group(1)), int(m.group(2)) + raise click.ClickException(f"cannot parse tuple {s}") + + +# ---------------------------------------------------------------------------- + + +def maybe_min(a: int, b: Optional[int]) -> int: + if b is not None: + return min(a, b) + return a + + +# ---------------------------------------------------------------------------- + + +def file_ext(name: Union[str, Path]) -> str: + return str(name).split(".")[-1] + + +# ---------------------------------------------------------------------------- + + +def is_image_ext(fname: Union[str, Path]) -> bool: + ext = file_ext(fname).lower() + return f".{ext}" in PIL.Image.EXTENSION + + +# ---------------------------------------------------------------------------- + + +def open_image_folder(source_dir, *, max_images: Optional[int]): + input_images = [ + str(f) + for f in sorted(Path(source_dir).rglob("*")) + if is_image_ext(f) and os.path.isfile(f) + ] + arch_fnames = { + fname: os.path.relpath(fname, source_dir).replace("\\", "/") + for fname in input_images + } + max_idx = maybe_min(len(input_images), max_images) + + # Load labels. + labels = dict() + meta_fname = os.path.join(source_dir, "dataset.json") + if os.path.isfile(meta_fname): + with open(meta_fname, "r") as file: + data = json.load(file)["labels"] + if data is not None: + labels = {x[0]: x[1] for x in data} + + # No labels available => determine from top-level directory names. + if len(labels) == 0: + toplevel_names = { + arch_fname: arch_fname.split("/")[0] if "/" in arch_fname else "" + for arch_fname in arch_fnames.values() + } + toplevel_indices = { + toplevel_name: idx + for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values()))) + } + if len(toplevel_indices) > 1: + labels = { + arch_fname: toplevel_indices[toplevel_name] + for arch_fname, toplevel_name in toplevel_names.items() + } + + def iterate_images(): + for idx, fname in enumerate(input_images): + img = np.array(PIL.Image.open(fname)) + yield dict(img=img, label=labels.get(arch_fnames.get(fname))) + if idx >= max_idx - 1: + break + + return max_idx, iterate_images() + + +# ---------------------------------------------------------------------------- + + +def open_image_zip(source, *, max_images: Optional[int]): + with zipfile.ZipFile(source, mode="r") as z: + input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)] + max_idx = maybe_min(len(input_images), max_images) + + # Load labels. 
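+        # dataset.json, when present, is expected to map archive filenames to
+        # integer class labels, e.g.
+        #   {"labels": [["00000/img00000000.png", 6], ["00000/img00000001.png", 9]]}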
+ labels = dict() + if "dataset.json" in z.namelist(): + with z.open("dataset.json", "r") as file: + data = json.load(file)["labels"] + if data is not None: + labels = {x[0]: x[1] for x in data} + + def iterate_images(): + with zipfile.ZipFile(source, mode="r") as z: + for idx, fname in enumerate(input_images): + with z.open(fname, "r") as file: + img = np.array(PIL.Image.open(file)) + yield dict(img=img, label=labels.get(fname)) + if idx >= max_idx - 1: + break + + return max_idx, iterate_images() + + +# ---------------------------------------------------------------------------- + + +def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]): + import cv2 # pyright: ignore [reportMissingImports] # pip install opencv-python + import lmdb # pyright: ignore [reportMissingImports] # pip install lmdb + + with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn: + max_idx = maybe_min(txn.stat()["entries"], max_images) + + def iterate_images(): + with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn: + for idx, (_key, value) in enumerate(txn.cursor()): + try: + try: + img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1) + if img is None: + raise IOError("cv2.imdecode failed") + img = img[:, :, ::-1] # BGR => RGB + except IOError: + img = np.array(PIL.Image.open(io.BytesIO(value))) + yield dict(img=img, label=None) + if idx >= max_idx - 1: + break + except: + print(sys.exc_info()[1]) + + return max_idx, iterate_images() + + +# ---------------------------------------------------------------------------- + + +def open_cifar10(tarball: str, *, max_images: Optional[int]): + images = [] + labels = [] + + with tarfile.open(tarball, "r:gz") as tar: + for batch in range(1, 6): + member = tar.getmember(f"cifar-10-batches-py/data_batch_{batch}") + with tar.extractfile(member) as file: + data = pickle.load(file, encoding="latin1") + images.append(data["data"].reshape(-1, 3, 32, 32)) + labels.append(data["labels"]) + + images = np.concatenate(images) + labels = np.concatenate(labels) + images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC + assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8 + assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64] + assert np.min(images) == 0 and np.max(images) == 255 + assert np.min(labels) == 0 and np.max(labels) == 9 + + max_idx = maybe_min(len(images), max_images) + + def iterate_images(): + for idx, img in enumerate(images): + yield dict(img=img, label=int(labels[idx])) + if idx >= max_idx - 1: + break + + return max_idx, iterate_images() + + +# ---------------------------------------------------------------------------- + + +def open_mnist(images_gz: str, *, max_images: Optional[int]): + labels_gz = images_gz.replace("-images-idx3-ubyte.gz", "-labels-idx1-ubyte.gz") + assert labels_gz != images_gz + images = [] + labels = [] + + with gzip.open(images_gz, "rb") as f: + images = np.frombuffer(f.read(), np.uint8, offset=16) + with gzip.open(labels_gz, "rb") as f: + labels = np.frombuffer(f.read(), np.uint8, offset=8) + + images = images.reshape(-1, 28, 28) + images = np.pad(images, [(0, 0), (2, 2), (2, 2)], "constant", constant_values=0) + assert images.shape == (60000, 32, 32) and images.dtype == np.uint8 + assert labels.shape == (60000,) and labels.dtype == np.uint8 + assert np.min(images) == 0 and np.max(images) == 255 + assert np.min(labels) == 0 and np.max(labels) == 9 + + max_idx = maybe_min(len(images), max_images) + + def iterate_images(): + for idx, img in 
enumerate(images): + yield dict(img=img, label=int(labels[idx])) + if idx >= max_idx - 1: + break + + return max_idx, iterate_images() + + +# ---------------------------------------------------------------------------- + + +def make_transform( + transform: Optional[str], output_width: Optional[int], output_height: Optional[int] +) -> Callable[[np.ndarray], Optional[np.ndarray]]: + def scale(width, height, img): + w = img.shape[1] + h = img.shape[0] + if width == w and height == h: + return img + img = PIL.Image.fromarray(img) + ww = width if width is not None else w + hh = height if height is not None else h + img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS) + return np.array(img) + + def center_crop(width, height, img): + crop = np.min(img.shape[:2]) + img = img[ + (img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, + (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2, + ] + if img.ndim == 2: + img = img[:, :, np.newaxis].repeat(3, axis=2) + img = PIL.Image.fromarray(img, "RGB") + img = img.resize((width, height), PIL.Image.Resampling.LANCZOS) + return np.array(img) + + def center_crop_wide(width, height, img): + ch = int(np.round(width * img.shape[0] / img.shape[1])) + if img.shape[1] < width or ch < height: + return None + + img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2] + if img.ndim == 2: + img = img[:, :, np.newaxis].repeat(3, axis=2) + img = PIL.Image.fromarray(img, "RGB") + img = img.resize((width, height), PIL.Image.Resampling.LANCZOS) + img = np.array(img) + + canvas = np.zeros([width, width, 3], dtype=np.uint8) + canvas[(width - height) // 2 : (width + height) // 2, :] = img + return canvas + + if transform is None: + return functools.partial(scale, output_width, output_height) + if transform == "center-crop": + if output_width is None or output_height is None: + raise click.ClickException( + "must specify --resolution=WxH when using " + transform + "transform" + ) + return functools.partial(center_crop, output_width, output_height) + if transform == "center-crop-wide": + if output_width is None or output_height is None: + raise click.ClickException( + "must specify --resolution=WxH when using " + transform + " transform" + ) + return functools.partial(center_crop_wide, output_width, output_height) + assert False, "unknown transform" + + +# ---------------------------------------------------------------------------- + + +def open_dataset(source, *, max_images: Optional[int]): + if os.path.isdir(source): + if source.rstrip("/").endswith("_lmdb"): + return open_lmdb(source, max_images=max_images) + else: + return open_image_folder(source, max_images=max_images) + elif os.path.isfile(source): + if os.path.basename(source) == "cifar-10-python.tar.gz": + return open_cifar10(source, max_images=max_images) + elif os.path.basename(source) == "train-images-idx3-ubyte.gz": + return open_mnist(source, max_images=max_images) + elif file_ext(source) == "zip": + return open_image_zip(source, max_images=max_images) + else: + assert False, "unknown archive type" + else: + raise click.ClickException(f"Missing input file or directory: {source}") + + +# ---------------------------------------------------------------------------- + + +def open_dest( + dest: str, +) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]: + dest_ext = file_ext(dest) + + if dest_ext == "zip": + if os.path.dirname(dest) != "": + os.makedirs(os.path.dirname(dest), exist_ok=True) + zf = zipfile.ZipFile(file=dest, mode="w", compression=zipfile.ZIP_STORED) + + def 
zip_write_bytes(fname: str, data: Union[bytes, str]):
+            zf.writestr(fname, data)
+
+        return "", zip_write_bytes, zf.close
+    else:
+        # If the output folder already exists, check that it is
+        # empty.
+        #
+        # Note: creating the output directory is not strictly
+        # necessary as folder_write_bytes() also mkdirs, but it's better
+        # to give an error message earlier in case the dest folder
+        # somehow cannot be created.
+        if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
+            raise click.ClickException("--dest folder must be empty")
+        os.makedirs(dest, exist_ok=True)
+
+        def folder_write_bytes(fname: str, data: Union[bytes, str]):
+            os.makedirs(os.path.dirname(fname), exist_ok=True)
+            with open(fname, "wb") as fout:
+                if isinstance(data, str):
+                    data = data.encode("utf8")
+                fout.write(data)
+
+        return dest, folder_write_bytes, lambda: None
+
+
+# ----------------------------------------------------------------------------
+
+
+def main(
+    source: str,
+    dest: str,
+    max_images: Optional[int],
+    transform: Optional[str],  # ["center-crop", "center-crop-wide"]
+    resolution: Optional[Tuple[int, int]],  # WxH
+):
+    """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
+
+    The input dataset format is guessed from the --source argument:
+
+    \b
+    --source *_lmdb/                     Load LSUN dataset
+    --source cifar-10-python.tar.gz      Load CIFAR-10 dataset
+    --source train-images-idx3-ubyte.gz  Load MNIST dataset
+    --source path/                       Recursively load all images from path/
+    --source dataset.zip                 Recursively load all images from dataset.zip
+
+    Specifying the output format and path:
+
+    \b
+    --dest /path/to/dir                  Save output files under /path/to/dir
+    --dest /path/to/dataset.zip          Save output files into /path/to/dataset.zip
+
+    The output dataset format can be either an image folder or an uncompressed zip archive.
+    Zip archives make it easier to move datasets around file servers and clusters, and may
+    offer better training performance on network file systems.
+
+    Images within the dataset archive will be stored as uncompressed PNG.
+    Uncompressed PNGs can be efficiently decoded in the training loop.
+
+    Class labels are stored in a file called 'dataset.json' that is stored at the
+    dataset root folder. This file has the following structure:
+
+    \b
+    {
+        "labels": [
+            ["00000/img00000000.png",6],
+            ["00000/img00000001.png",9],
+            ... repeated for every image in the dataset
+            ["00049/img00049999.png",1]
+        ]
+    }
+
+    If the 'dataset.json' file cannot be found, class labels are determined from
+    top-level directory names.
+
+    Image scale/crop and resolution requirements:
+
+    Output images must be square-shaped and they must all have the same power-of-two
+    dimensions.
+
+    To scale arbitrary input image size to a specific width and height, use the
+    --resolution option. Output resolution will be either the original
+    input resolution (if resolution was not specified) or the one specified with
+    --resolution option.
+
+    Use the --transform=center-crop or --transform=center-crop-wide options to apply a
+    center crop transform on the input image. These options should be used with the
+    --resolution option.
For example: + + \b + python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\ + --transform=center-crop-wide --resolution=512x384 + """ + + PIL.Image.init() + + if dest == "": + raise click.ClickException( + "--dest output filename or directory must not be an empty string" + ) + + num_files, input_iter = open_dataset(source, max_images=max_images) + archive_root_dir, save_bytes, close_dest = open_dest(dest) + + if resolution is None: + resolution = (None, None) + transform_image = make_transform(transform, *resolution) + + dataset_attrs = None + + labels = [] + for idx, image in tqdm(enumerate(input_iter), total=num_files): + idx_str = f"{idx:08d}" + archive_fname = f"{idx_str[:5]}/img{idx_str}.png" + + # Apply crop and resize. + img = transform_image(image["img"]) + if img is None: + continue + + # Error check to require uniform image attributes across + # the whole dataset. + channels = img.shape[2] if img.ndim == 3 else 1 + cur_image_attrs = { + "width": img.shape[1], + "height": img.shape[0], + "channels": channels, + } + if dataset_attrs is None: + dataset_attrs = cur_image_attrs + width = dataset_attrs["width"] + height = dataset_attrs["height"] + if width != height: + raise click.ClickException( + f"Image dimensions after scale and crop are required to be square. Got {width}x{height}" + ) + if dataset_attrs["channels"] not in [1, 3]: + raise click.ClickException( + "Input images must be stored as RGB or grayscale" + ) + if width != 2 ** int(np.floor(np.log2(width))): + raise click.ClickException( + "Image width/height after scale and crop are required to be power-of-two" + ) + elif dataset_attrs != cur_image_attrs: + err = [ + f" dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}" + for k in dataset_attrs.keys() + ] + raise click.ClickException( + f"Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n" + + "\n".join(err) + ) + + # Save the image as an uncompressed PNG. + img = PIL.Image.fromarray(img, {1: "L", 3: "RGB"}[channels]) + image_bits = io.BytesIO() + img.save(image_bits, format="png", compress_level=0, optimize=False) + save_bytes( + os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer() + ) + labels.append( + [archive_fname, image["label"]] if image["label"] is not None else None + ) + + metadata = {"labels": labels if all(x is not None for x in labels) else None} + save_bytes(os.path.join(archive_root_dir, "dataset.json"), json.dumps(metadata)) + close_dest() + + +# ---------------------------------------------------------------------------- + + +def parse_tuple(string): + try: + # split the string and convert each part to an integer + parsed = tuple(map(int, string.split("x"))) + if len(parsed) != 2: + raise ValueError + except ValueError: + msg = f"{string} is an invalid resolution format. 
It should be WxH (e.g., 512x512)" + raise argparse.ArgumentTypeError(msg) + return parsed + + +# Create the parser +parser = argparse.ArgumentParser(description="Process some images.") + +# Add the arguments +parser.add_argument( + "--source", + help="Input directory or archive name", + metavar="PATH", + type=str, + required=True, +) + +parser.add_argument( + "--dest", + help="Output directory or archive name", + metavar="PATH", + type=str, + required=True, +) + +parser.add_argument( + "--max-images", + help="Maximum number of images to output", + metavar="INT", + type=int, +) + +parser.add_argument( + "--transform", + help="Input crop/resize mode", + metavar="MODE", + type=str, + choices=["center-crop", "center-crop-wide"], +) + +parser.add_argument( + "--resolution", + help="Output resolution (e.g., 512x512)", + metavar="WxH", + type=parse_tuple, +) + +# Parse the arguments +args = parser.parse_args() + +# ---------------------------------------------------------------------------- + +if __name__ == "__main__": + + main( + source=args.source, + dest=args.dest, + max_images=args.max_images, + transform=args.transform, + resolution=args.resolution, + ) + +# ---------------------------------------------------------------------------- diff --git a/examples/generative/diffusion/fid.py b/examples/generative/diffusion/fid.py new file mode 100644 index 0000000000..291ba98a6f --- /dev/null +++ b/examples/generative/diffusion/fid.py @@ -0,0 +1,194 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO (mnabian) refactor, generalize + +"""Script for calculating Frechet Inception Distance (FID).""" + +import os +import pickle + +import hydra +import numpy as np +import torch +import tqdm +from dataset import ImageFolderDataset +from omegaconf import DictConfig +from utils import open_url + +from modulus.metrics.diffusion import calculate_fid_from_inception_stats +from modulus.distributed import DistributedManager +from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper + + +def calculate_inception_stats( + image_path, + dist, + logger0, + num_expected=None, + seed=0, + max_batch_size=64, + num_workers=3, + prefetch_factor=2, +): + device = dist.device + # Rank 0 goes first. + if dist.world_size > 1 and dist.rank != 0: + torch.distributed.barrier() + + # Load Inception-v3 model. + # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz + logger0.info("Loading Inception-v3 model...") + detector_url = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl" + detector_kwargs = dict(return_features=True) + feature_dim = 2048 + with open_url(detector_url, verbose=(dist.get_rank() == 0)) as f: + detector_net = pickle.load(f).to(device) + + # List images. 
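+    # image_path may be a directory of PNGs (as written by generate.py) or a
+    # ZIP archive created with dataset_tool.py; ImageFolderDataset handles both.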
+ logger0.info(f'Loading images from "{image_path}"...') + dataset_obj = ImageFolderDataset( + path=image_path, max_size=num_expected, random_seed=seed + ) + if num_expected is not None and len(dataset_obj) < num_expected: + raise ValueError( + f"Found {len(dataset_obj)} images, but expected at least {num_expected}" + ) + if len(dataset_obj) < 2: + raise ValueError( + f"Found {len(dataset_obj)} images, but need at least 2 to compute statistics" + ) + + # Other ranks follow. + if dist.world_size > 1 and dist.rank == 0: + torch.distributed.barrier() + + # Divide images into batches. + num_batches = ( + (len(dataset_obj) - 1) // (max_batch_size * dist.world_size) + 1 + ) * dist.world_size + all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches) + rank_batches = all_batches[dist.rank :: dist.world_size] + data_loader = torch.utils.data.DataLoader( + dataset_obj, + batch_sampler=rank_batches, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + ) + + # Accumulate statistics. + logger0.info(f"Calculating statistics for {len(dataset_obj)} images...") + mu = torch.zeros([feature_dim], dtype=torch.float64, device=device) + sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device) + for images, _ in tqdm.tqdm( + data_loader, unit="batch", disable=(dist.get_rank() != 0) + ): + if dist.world_size > 1: + torch.distributed.barrier() + if images.shape[0] == 0: + continue + if images.shape[1] == 1: + images = images.repeat([1, 3, 1, 1]) + features = detector_net(images.to(device), **detector_kwargs).to(torch.float64) + mu += features.sum(0) + sigma += features.T @ features + + # Calculate grand totals. + if dist.world_size > 1: + torch.distributed.all_reduce(mu) + torch.distributed.all_reduce(sigma) + mu /= len(dataset_obj) + sigma -= mu.ger(mu) * len(dataset_obj) + sigma /= len(dataset_obj) - 1 + return mu.cpu().numpy(), sigma.cpu().numpy() + + +def calc(image_path, ref_path, num_expected, seed, batch, dist, logger, logger0): + """Calculate FID for a given set of images.""" + + logger0.info(f'Loading dataset reference statistics from "{ref_path}"...') + ref = None + if dist.rank == 0: + with open_url(ref_path) as f: + ref = dict(np.load(f)) + + mu, sigma = calculate_inception_stats( + image_path=image_path, + dist=dist, + logger0=logger0, + num_expected=num_expected, + seed=seed, + max_batch_size=batch, + ) + logger0.info("Calculating FID...") + if dist.rank == 0: + fid = calculate_fid_from_inception_stats(mu, sigma, ref["mu"], ref["sigma"]) + logger.info(f"{fid:g}") + if dist.world_size > 1: + torch.distributed.barrier() + + +def ref(dataset_path, dest_path, batch, dist, logger0): + """Calculate dataset reference statistics needed by 'calc'.""" + + mu, sigma = calculate_inception_stats( + image_path=dataset_path, dist=dist, logger0=logger0, max_batch_size=batch + ) + logger0.info(f'Saving dataset reference statistics to "{dest_path}"...') + if dist.rank == 0: + if os.path.dirname(dest_path): + os.makedirs(os.path.dirname(dest_path), exist_ok=True) + np.savez(dest_path, mu=mu, sigma=sigma) + + if dist.world_size > 1: + torch.distributed.barrier() + logger0.info("Done.") + + +# ---------------------------------------------------------------------------- + + +@hydra.main(version_base="1.2", config_path="conf", config_name="config_fid") +def main(cfg: DictConfig) -> None: + + """Calculate Frechet Inception Distance (FID).""" + + # Initialize distributed manager. + DistributedManager.initialize() + dist = DistributedManager() + + # Initialize logger. 
+ logger = PythonLogger("main") # General python logger + logger0 = RankZeroLoggingWrapper(logger, dist) + logger.file_logging() + + if cfg.mode == "calc": + calc( + cfg.image_path, + cfg.ref_path, + cfg.num_expected, + cfg.seed, + cfg.batch, + dist, + logger, + logger0, + ) + elif cfg.mode == "ref": + ref(cfg.dataset_path, cfg.dest_path, cfg.batch, dist, logger0) + else: + raise ValueError(f"Unknown mode {cfg.mode}") + + +if __name__ == "__main__": + main() diff --git a/examples/generative/diffusion/generate.py b/examples/generative/diffusion/generate.py new file mode 100644 index 0000000000..d868bc67cb --- /dev/null +++ b/examples/generative/diffusion/generate.py @@ -0,0 +1,334 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pickle # TODO remove + +import hydra +import numpy as np +import PIL.Image +import torch +import tqdm +from omegaconf import DictConfig +from utils import StackedRandomGenerator, open_url + +from modulus.distributed import DistributedManager +from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper + + +def sampler( + net, + latents, + class_labels=None, + randn_like=torch.randn_like, + num_steps=18, + sigma_min=None, + sigma_max=None, + rho=7, + solver="heun", + discretization="edm", + schedule="linear", + scaling="none", + epsilon_s=1e-3, + C_1=0.001, + C_2=0.008, + M=1000, + alpha=1, + s_churn=0, + s_min=0, + s_max=float("inf"), + s_noise=1, +): + """ + Generalized sampler, representing the superset of all sampling methods discussed + in the paper "Elucidating the Design Space of Diffusion-Based Generative Models" + """ + if solver not in ["euler", "heun"]: + raise ValueError(f'Invalid solver "{solver}"') + if discretization not in ["vp", "ve", "iddpm", "edm"]: + raise ValueError(f'Invalid discretization "{discretization}"') + if schedule not in ["vp", "ve", "linear"]: + raise ValueError(f'Invalid schedule "{schedule}"') + if scaling is not None and scaling not in ["vp"]: + raise ValueError(f'Invalid scaling "{scaling}"') + + # Helper functions for VP & VE noise level schedules. + vp_sigma = ( + lambda beta_d, beta_min: lambda t: ( + np.e ** (0.5 * beta_d * (t**2) + beta_min * t) - 1 + ) + ** 0.5 + ) + vp_sigma_deriv = ( + lambda beta_d, beta_min: lambda t: 0.5 + * (beta_min + beta_d * t) + * (sigma(t) + 1 / sigma(t)) + ) + vp_sigma_inv = ( + lambda beta_d, beta_min: lambda sigma: ( + (beta_min**2 + 2 * beta_d * (sigma**2 + 1).log()).sqrt() - beta_min + ) + / beta_d + ) + ve_sigma = lambda t: t.sqrt() + ve_sigma_deriv = lambda t: 0.5 / t.sqrt() + ve_sigma_inv = lambda sigma: sigma**2 + + # Select default noise level range based on the specified time step discretization. 
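+    # Defaults when sigma_min / sigma_max are omitted: ve -> [0.02, 100],
+    # iddpm -> [0.002, 81], edm -> [0.002, 80]; for vp they are derived from the
+    # VP schedule (beta_d=19.9, beta_min=0.1) evaluated at t=epsilon_s and t=1.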
+ if sigma_min is None: + vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s) + sigma_min = {"vp": vp_def, "ve": 0.02, "iddpm": 0.002, "edm": 0.002}[ + discretization + ] + if sigma_max is None: + vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1) + sigma_max = {"vp": vp_def, "ve": 100, "iddpm": 81, "edm": 80}[discretization] + + # Adjust noise levels based on what's supported by the network. + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + # Compute corresponding betas for VP. + vp_beta_d = ( + 2 + * (np.log(sigma_min**2 + 1) / epsilon_s - np.log(sigma_max**2 + 1)) + / (epsilon_s - 1) + ) + vp_beta_min = np.log(sigma_max**2 + 1) - 0.5 * vp_beta_d + + # Define time steps in terms of noise level. + step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + if discretization == "vp": + orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) + sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) + elif discretization == "ve": + orig_t_steps = (sigma_max**2) * ( + (sigma_min**2 / sigma_max**2) ** (step_indices / (num_steps - 1)) + ) + sigma_steps = ve_sigma(orig_t_steps) + elif discretization == "iddpm": + u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device) + alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 + for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1 + u[j - 1] = ( + (u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1 + ).sqrt() + u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] + sigma_steps = u_filtered[ + ((len(u_filtered) - 1) / (num_steps - 1) * step_indices) + .round() + .to(torch.int64) + ] + else: # edm sigma steps + sigma_steps = ( + sigma_max ** (1 / rho) + + step_indices + / (num_steps - 1) + * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) + ) ** rho + + # Define noise level schedule. + if schedule == "vp": + sigma = vp_sigma(vp_beta_d, vp_beta_min) + sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) + sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) + elif schedule == "ve": + sigma = ve_sigma + sigma_deriv = ve_sigma_deriv + sigma_inv = ve_sigma_inv + else: + sigma = lambda t: t + sigma_deriv = lambda t: 1 + sigma_inv = lambda sigma: sigma + + # Define scaling schedule. + if scaling == "vp": + s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() + s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) + else: + s = lambda t: 1 + s_deriv = lambda t: 0 + + # Compute final time steps based on the corresponding noise levels. + t_steps = sigma_inv(net.round_sigma(sigma_steps)) + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + t_next = t_steps[0] + x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = ( + min(s_churn / num_steps, np.sqrt(2) - 1) + if s_min <= sigma(t_cur) <= s_max + else 0 + ) + t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) + x_hat = s(t_hat) / s(t_cur) * x_cur + ( + sigma(t_hat) ** 2 - sigma(t_cur) ** 2 + ).clip(min=0).sqrt() * s(t_hat) * s_noise * randn_like(x_cur) + + # Euler step. 
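+        # One Euler step of the probability-flow ODE used in EDM,
+        #   dx/dt = (sigma'(t)/sigma(t) + s'(t)/s(t)) * x
+        #           - sigma'(t) * s(t) / sigma(t) * D(x / s(t); sigma(t)),
+        # evaluated at the noise-inflated state (x_hat, t_hat).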
+ h = t_next - t_hat + denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64) + d_cur = ( + sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat) + ) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised + x_prime = x_hat + alpha * h * d_cur + t_prime = t_hat + alpha * h + + # Apply 2nd order correction. + if solver == "euler" or i == num_steps - 1: + x_next = x_hat + h * d_cur + else: + assert solver == "heun" + denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to( + torch.float64 + ) + d_prime = ( + sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime) + ) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised + x_next = x_hat + h * ( + (1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime + ) + + return x_next + + +@hydra.main(version_base="1.2", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + """Generate random images using the techniques described in the paper + "Elucidating the Design Space of Diffusion-Based Generative Models". + """ + ckpt_filename = cfg.ckpt_filename + img_outdir = cfg.img_outdir + subdirs = cfg.subdirs + gen_seeds = cfg.gen_seeds + class_idx = cfg.class_idx + max_batch_size = cfg.max_batch_size + + # Initialize distributed manager. + DistributedManager.initialize() + dist = DistributedManager() + device = dist.device + + # Initialize logger. + logger = PythonLogger("main") # General python logger + logger0 = RankZeroLoggingWrapper(logger, dist) + logger.file_logging() + + num_batches = ( + (len(gen_seeds) - 1) // (max_batch_size * dist.world_size) + 1 + ) * dist.world_size + all_batches = torch.as_tensor(gen_seeds).tensor_split(num_batches) + rank_batches = all_batches[dist.rank :: dist.world_size] + + # Rank 0 goes first. + if dist.world_size > 1 and dist.rank != 0: + torch.distributed.barrier() + + # Load network. + logger0.info(f'Loading network from "{ckpt_filename}"...') + with open_url(ckpt_filename, verbose=(dist.rank == 0)) as f: + net = pickle.load(f)["ema"].to(device) + + # Other ranks follow. + if dist.world_size > 1 and dist.rank == 0: + torch.distributed.barrier() + + # Loop over batches. + logger0.info(f'Generating {len(gen_seeds)} images to "{img_outdir}"...') + for batch_seeds in tqdm.tqdm(rank_batches, unit="batch", disable=(dist.rank != 0)): + if dist.world_size > 1: + torch.distributed.barrier() + batch_size = len(batch_seeds) + if batch_size == 0: + continue + + # Pick latents and labels. + rnd = StackedRandomGenerator(device, batch_seeds) + latents = rnd.randn( + [batch_size, net.img_channels, net.img_resolution, net.img_resolution], + device=device, + ) + class_labels = None + if net.label_dim: + class_labels = torch.eye(net.label_dim, device=device)[ + rnd.randint(net.label_dim, size=[batch_size], device=device) + ] + if class_idx is not None: + class_labels[:, :] = 0 + class_labels[:, class_idx] = 1 + + # Generate images. + images = sampler( + net, + latents, + class_labels, + randn_like=rnd.randn_like, + num_steps=cfg.num_steps, + sigma_min=cfg.sigma_min, + sigma_max=cfg.sigma_max, + rho=cfg.rho, + solver=cfg.solver, + discretization=cfg.discretization, + schedule=cfg.schedule, + scaling=cfg.scaling, + epsilon_s=1e-3, + C_1=0.001, + C_2=0.008, + M=1000, + alpha=1, + s_churn=cfg.s_churn, + s_min=cfg.s_min, + s_max=cfg.s_max, + s_noise=cfg.s_noise, + ) + + # Save images. 
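+        # Sampler output lies roughly in [-1, 1]; map it to uint8 via
+        # x * 127.5 + 128, clip to [0, 255], and reorder NCHW -> NHWC for PIL.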
+        images_np = (
+            (images * 127.5 + 128)
+            .clip(0, 255)
+            .to(torch.uint8)
+            .permute(0, 2, 3, 1)
+            .cpu()
+            .numpy()
+        )
+        for seed, image_np in zip(batch_seeds, images_np):
+            image_dir = (
+                os.path.join(img_outdir, f"{seed-seed%1000:06d}")
+                if subdirs
+                else img_outdir
+            )
+            os.makedirs(image_dir, exist_ok=True)
+            image_path = os.path.join(image_dir, f"{seed:06d}.png")
+            if image_np.shape[2] == 1:
+                PIL.Image.fromarray(image_np[:, :, 0], "L").save(image_path)
+            else:
+                PIL.Image.fromarray(image_np, "RGB").save(image_path)
+
+    # Done.
+    if dist.world_size > 1:
+        torch.distributed.barrier()
+    logger0.info("Done.")
+
+
+# ----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    main()
+
+# ----------------------------------------------------------------------------
diff --git a/examples/generative/diffusion/train.py b/examples/generative/diffusion/train.py
new file mode 100644
index 0000000000..beb3a51faf
--- /dev/null
+++ b/examples/generative/diffusion/train.py
@@ -0,0 +1,284 @@
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Train a diffusion-based generative model using the techniques described in the
+paper "Elucidating the Design Space of Diffusion-Based Generative Models"."""
+
+import os
+
+os.environ["TORCHELASTIC_ENABLE_FILE_TIMER"] = "1"  # TODO is this needed?
+
+import json
+import re
+
+import hydra
+import torch
+from omegaconf import DictConfig
+from training_loop import training_loop
+from utils import EasyDict, construct_class_by_name
+
+from modulus.distributed import DistributedManager
+from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper
+
+try:
+    from apex.optimizers import FusedAdam
+
+    apex_imported = True
+except ImportError:
+    apex_imported = False
+
+
+@hydra.main(version_base="1.2", config_path="conf", config_name="config")
+def main(cfg: DictConfig) -> None:
+    """Train a diffusion-based generative model using the techniques described in the
+    paper "Elucidating the Design Space of Diffusion-Based Generative Models".
+
+    Examples:
+
+    \b
+    # Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs
+    torchrun --standalone --nproc_per_node=8 train.py outdir=training-runs \\
+        data=datasets/cifar10-32x32.zip cond=true arch=ddpmpp
+    """
+
+    # Initialize distributed manager.
+    DistributedManager.initialize()
+    dist = DistributedManager()
+
+    # Initialize logger.
+    logger = PythonLogger("main")  # General python logger
+    logger0 = RankZeroLoggingWrapper(logger, dist)
+    logger.file_logging()
+
+    # TODO add mlflow/wandb logging
+
+    # Initialize config dict.
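+    # 'c' collects every training option below and is eventually unpacked into
+    # training_loop(**c, ...) at the end of this function, following the EDM
+    # reference training script.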
+ c = EasyDict() + c.dataset = cfg.dataset + c.dataset_kwargs = EasyDict( + class_name="dataset.ImageFolderDataset", + path=cfg.data, + use_labels=cfg.cond, + xflip=cfg.xflip, + cache=cfg.cache, + ) + c.data_loader_kwargs = EasyDict( + pin_memory=True, num_workers=cfg.workers, prefetch_factor=2 + ) + c.network_kwargs = EasyDict() + c.loss_kwargs = EasyDict() + c.optimizer_kwargs = EasyDict( + class_name="apex.optimizers.FusedAdam" + if apex_imported and cfg.fused_adam + else "torch.optim.Adam", + lr=cfg.lr, + betas=[0.9, 0.999], + eps=1e-8, + ) + dataset_name = cfg.dataset + + # Validate dataset options. + try: + dataset_obj = construct_class_by_name(**c.dataset_kwargs) + dataset_name = dataset_obj.name + c.dataset_kwargs.resolution = ( + dataset_obj.resolution + ) # be explicit about dataset resolution + c.dataset_kwargs.max_size = len(dataset_obj) # be explicit about dataset size + if cfg.cond and not dataset_obj.has_labels: + raise ValueError("cond=True requires labels specified in dataset.json") + del dataset_obj # conserve memory + except IOError as err: + raise ValueError(f"data: {err}") + + # Network architecture. + # if cfg.arch == 'ddpmpp-cwb-v2': + # c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard') #, attn_resolutions=[28] + # c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[1,2,2,4,4,8], attn_resolutions=[14]) #era5-cwb, larger run, 448x448 + + # elif cfg.arch == 'ddpmpp-cwb-v1': + # c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard') #, attn_resolutions=[28] + # c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[1,2,2,4,4], attn_resolutions=[28]) #era5-cwb, 448x448 + + # elif cfg.arch == 'ddpmpp-cwb-v0-regression': + # c.network_kwargs.update(model_type='SongUNet', embedding_type='zero', encoder_type='standard', decoder_type='standard') #, attn_resolutions=[28] + # c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[1,2,2,2,2], attn_resolutions=[28]) #era5-cwb, 448x448 + + # elif cfg.arch == 'ddpmpp-cwb-v0': + # c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard') #, attn_resolutions=[28] + # c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[1,2,2,2,2], attn_resolutions=[28]) #era5-cwb, 448x448 + + # elif cfg.arch == 'ddpmpp-cifar': + # c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard') #, attn_resolutions=[28] + # c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[2,2,2]) #cifar-10, 32x32 + + # elif cfg.arch == 'ncsnpp': + # c.network_kwargs.update(model_type='SongUNet', embedding_type='fourier', encoder_type='residual', decoder_type='standard') + # c.network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels=128, channel_mult=[2,2,2]) + + if cfg.arch == "ddpmpp": + c.network_kwargs.update( + model_type="SongUNet", + embedding_type="positional", + encoder_type="standard", + decoder_type="standard", + ) + c.network_kwargs.update( + channel_mult_noise=1, + resample_filter=[1, 1], + model_channels=128, + channel_mult=[2, 2, 2], + ) + elif cfg.arch == "ncsnpp": + c.network_kwargs.update( + 
model_type="SongUNet", + embedding_type="fourier", + encoder_type="residual", + decoder_type="standard", + ) + c.network_kwargs.update( + channel_mult_noise=2, + resample_filter=[1, 3, 3, 1], + model_channels=128, + channel_mult=[2, 2, 2], + ) + else: + assert cfg.arch == "adm" + c.network_kwargs.update( + model_type="DhariwalUNet", model_channels=192, channel_mult=[1, 2, 3, 4] + ) + + # Preconditioning & loss function. + if cfg.precond == "vp": + c.network_kwargs.class_name = "modulus.models.diffusion.VPPrecond" + c.loss_kwargs.class_name = "modulus.metrics.diffusion.VPLoss" + elif cfg.precond == "ve": + c.network_kwargs.class_name = "modulus.models.diffusion.VEPrecond" + c.loss_kwargs.class_name = "modulus.metrics.diffusion.VELoss" + elif cfg.precond == "edm": + c.network_kwargs.class_name = "modulus.models.diffusion.EDMPrecond" + c.loss_kwargs.class_name = "modulus.metrics.diffusion.EDMLoss" + # elif cfg.precond == 'unetregression': + # c.network_kwargs.class_name = 'training.networks.UNet' + # c.loss_kwargs.class_name = 'training.loss.RegressionLoss' + # elif cfg.precond == 'mixture': + # c.network_kwargs.class_name = 'training.networks.EDMPrecond' + # c.loss_kwargs.class_name = 'training.loss.MixtureLoss' + # elif cfg.precond == 'resloss': + # c.network_kwargs.class_name = 'training.networks.EDMPrecond' + # c.loss_kwargs.class_name = 'training.loss.ResLoss' + + # Network options. + if cfg.cbase is not None: + c.network_kwargs.model_channels = cfg.cbase + if cfg.cres is not None: + c.network_kwargs.channel_mult = cfg.cres + if cfg.augment: + raise NotImplementedError("Augmentation is not implemented") + c.network_kwargs.update(dropout=cfg.dropout, use_fp16=cfg.fp16) + + # Training options. + c.total_kimg = max(int(cfg.duration * 1000), 1) + c.ema_halflife_kimg = int(cfg.ema * 1000) + c.update(batch_size=cfg.batch, batch_gpu=cfg.batch_gpu) + c.update(loss_scaling=cfg.ls, cudnn_benchmark=cfg.bench) + c.update(kimg_per_tick=cfg.tick, snapshot_ticks=cfg.snap, state_dump_ticks=cfg.dump) + + # Random seed. + if cfg.seed is not None: + c.seed = cfg.seed + else: + seed = torch.randint(1 << 31, size=[], device=dist.device) + if dist.distributed: + torch.distributed.broadcast(seed, src=0) # TODO check if this fails + c.seed = int(seed) + + # check if resume.txt exists + resume_path = os.path.join(cfg.outdir, "resume.txt") + if os.path.exists(resume_path): + with open(resume_path, "r") as f: + cfg.resume = f.read() + f.close() + + logger0.info(f"cfg.resume: {cfg.resume}") + + # Transfer learning and resume. + if cfg.transfer is not None: + if cfg.resume is not None: + raise ValueError("transfer and resume cannot be specified at the same time") + c.resume_pkl = cfg.transfer + c.ema_rampup_ratio = None + elif cfg.resume is not None: # TODO replace prints with Modulus logger + print("gets into elif cfg.resume is not None ...") + match = re.fullmatch(r"training-state-(\d+).pt", os.path.basename(cfg.resume)) + print("match", match) + print("match.group(1)", match.group(1)) + c.resume_pkl = os.path.join( + os.path.dirname(cfg.resume), f"network-snapshot-{match.group(1)}.pkl" + ) + c.resume_kimg = int(match.group(1)) + c.resume_state_dump = cfg.resume + logger0.info(f"c.resume_pkl: {c.resume_pkl}") + logger0.info(f"c.resume_kimg: {c.resume_kimg}") + logger0.info(f"c.resume_state_dump: {c.resume_state_dump}") + + # Description string. 
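+    # e.g. "cifar10-32x32-cond-ddpmpp-edm-gpus8-batch128-fp32" (illustrative).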
+ cond_str = "cond" if c.dataset_kwargs.use_labels else "uncond" + dtype_str = "fp16" if c.network_kwargs.use_fp16 else "fp32" + desc = f"{dataset_name:s}-{cond_str:s}-{cfg.arch:s}-{cfg.precond:s}-gpus{dist.world_size:d}-batch{c.batch_size:d}-{dtype_str:s}" + if cfg.desc is not None: + desc += f"-{cfg.desc}" + + c.run_dir = cfg.outdir + + # # Weather data + # c.data_type = cfg.data_type + # c.data_config = cfg.data_config + # c.task = cfg.task + + # Print options. # TODO replace prints with Modulus logger + logger0.info("Training options:") + logger0.info(json.dumps(c, indent=2)) + logger0.info(f"Output directory: {c.run_dir}") + logger0.info(f"Dataset path: {c.dataset_kwargs.path}") + logger0.info(f"Class-conditional: {c.dataset_kwargs.use_labels}") + logger0.info(f"Network architecture: {cfg.arch}") + logger0.info(f"Preconditioning & loss: {cfg.precond}") + logger0.info(f"Number of GPUs: {dist.world_size}") + logger0.info(f"Batch size: {c.batch_size}") + logger0.info(f"Mixed-precision: {c.network_kwargs.use_fp16}") + + # Dry run? + if cfg.dry_run: + logger0.info("Dry run; exiting.") + return + + # Create output directory. + logger0.info("Creating output directory...") + if dist.rank == 0: + os.makedirs(c.run_dir, exist_ok=True) + with open(os.path.join(c.run_dir, "training_options.json"), "wt") as f: + json.dump(c, f, indent=2) + # utils.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True) + + # Train. + training_loop(**c, dist=dist, logger0=logger0) + + +# ---------------------------------------------------------------------------- + +if __name__ == "__main__": + main() + +# ---------------------------------------------------------------------------- diff --git a/examples/generative/diffusion/training_loop.py b/examples/generative/diffusion/training_loop.py new file mode 100644 index 0000000000..c5ce397d21 --- /dev/null +++ b/examples/generative/diffusion/training_loop.py @@ -0,0 +1,472 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main training loop.""" + +import copy +import json +import os +import pickle # TODO remove +import time + +import numpy as np +import psutil +import torch +from torch.nn.parallel import DistributedDataParallel +from training_stats import default_collector, report, report0 +from utils import ( + InfiniteSampler, + check_ddp_consistency, + construct_class_by_name, + copy_params_and_buffers, + ddp_sync, + format_time, + open_url, + print_module_summary, +) + +# # weather related +# from .YParams import YParams +# from .dataset import Era5Dataset, CWBDataset, CWBERA5DatasetV2, ZarrDataset + +# ---------------------------------------------------------------------------- + + +def training_loop( + run_dir=".", # Output directory. + dataset=None, # The dataset. Choose from ['cifar10']. + dataset_kwargs={}, # Options for training set. + data_loader_kwargs={}, # Options for torch.utils.data.DataLoader. 
+    network_kwargs={},  # Options for model and preconditioning.
+    loss_kwargs={},  # Options for loss function.
+    optimizer_kwargs={},  # Options for optimizer.
+    augment_kwargs=None,  # Options for augmentation pipeline, None = disable.
+    seed=0,  # Global random seed.
+    batch_size=512,  # Total batch size for one training iteration.
+    batch_gpu=None,  # Limit batch size per GPU, None = no limit.
+    total_kimg=200000,  # Training duration, measured in thousands of training images.
+    ema_halflife_kimg=500,  # Half-life of the exponential moving average (EMA) of model weights.
+    ema_rampup_ratio=0.05,  # EMA ramp-up coefficient, None = no rampup.
+    lr_rampup_kimg=10000,  # Learning rate ramp-up duration.
+    loss_scaling=1,  # Loss scaling factor for reducing FP16 under/overflows.
+    kimg_per_tick=50,  # Interval of progress prints.
+    snapshot_ticks=50,  # How often to save network snapshots, None = disable.
+    state_dump_ticks=500,  # How often to dump training state, None = disable.
+    resume_pkl=None,  # Start from the given network snapshot, None = random initialization.
+    resume_state_dump=None,  # Start from the given training state, None = reset training state.
+    resume_kimg=0,  # Start from the given training progress.
+    cudnn_benchmark=True,  # Enable torch.backends.cudnn.benchmark?
+    # data_type=None,
+    # data_config=None,
+    # task=None,
+    dist=None,  # Distributed object.
+    logger0=None,  # Rank 0 logger.
+):
+    # Initialize.
+    start_time = time.time()
+    device = dist.device
+    np.random.seed((seed * dist.world_size + dist.rank) % (1 << 31))
+    torch.manual_seed(np.random.randint(1 << 31))
+    torch.backends.cudnn.benchmark = cudnn_benchmark
+    torch.backends.cudnn.allow_tf32 = False
+    torch.backends.cuda.matmul.allow_tf32 = False
+    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+
+    # Select batch size per GPU.
+    batch_gpu_total = batch_size // dist.world_size
+    logger0.info(f"batch_gpu: {batch_gpu}")
+    if batch_gpu is None or batch_gpu > batch_gpu_total:
+        batch_gpu = batch_gpu_total
+    num_accumulation_rounds = batch_gpu_total // batch_gpu
+    assert batch_size == batch_gpu * num_accumulation_rounds * dist.world_size
+
+    # Load dataset.
+    supported_datasets = ["cifar10"]
+    if dataset is None:
+        raise RuntimeError("Please specify the dataset.")
+    if dataset not in supported_datasets:
+        raise ValueError(
+            f'Invalid dataset: "{dataset}". '
+            f"Supported datasets: {supported_datasets}."
+ ) + logger0.info(f"Loading {dataset} dataset...") + + # Load dataset: cifar10 + dataset_obj = construct_class_by_name( + **dataset_kwargs + ) # subclass of training.dataset.Dataset + dataset_sampler = InfiniteSampler( + dataset=dataset_obj, + rank=dist.rank, + num_replicas=dist.world_size, + seed=seed, + ) + dataset_iterator = iter( + torch.utils.data.DataLoader( + dataset=dataset_obj, + sampler=dataset_sampler, + batch_size=batch_gpu, + **data_loader_kwargs, + ) + ) + + # # Load dataset: weather + # yparams = YParams(data_type + '.yaml', config_name=data_config) + # if data_type == 'era5': + # dataset_obj = Era5Dataset(yparams, yparams.train_data_path, train=True, task=task) + # worker_init_fn = None + # elif data_type == 'cwb': + # dataset_obj = CWBDataset(yparams, yparams.train_data_path, train=True, task=task) + # worker_init_fn = None + # elif data_type == 'era5-cwb-v1': + # #filelist = os.listdir(path=yparams.cwb_data_dir + '/2018') + # #filelist = [name for name in filelist if "2018" in name] + # filelist = [] + # for root, dirs, files in os.walk(yparams.cwb_data_dir): + # for file in files: + # if '2022' not in file: + # filelist.append(file) + # dataset_obj = CWBERA5DatasetV2(yparams, filelist=filelist, chans=list(range(20)), train=True, task=task) + # worker_init_fn = dataset_obj.worker_init_fn + # elif data_type == 'era5-cwb-v2': + # dataset_obj = ZarrDataset(yparams, yparams.train_data_path, train=True) + # worker_init_fn = None + # elif data_type == 'era5-cwb-v3': + # dataset_obj = ZarrDataset(yparams, yparams.train_data_path, train=True) + # #worker_init_fn = dataset_obj.worker_init_fn + # worker_init_fn = None + + # dataset_sampler = InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed) + # dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, worker_init_fn=worker_init_fn, **data_loader_kwargs)) + + # img_in_channels = len(yparams.in_channels) #noise + low-res input + # if yparams.add_grid: + # img_in_channels = img_in_channels + yparams.N_grid_channels + + # img_out_channels = len(yparams.out_channels) + + # if use_mean_input: #add it to the args and store_true in yaml file + # img_in_channels = img_in_channels + yparams.N_grid_channels + img_out_channels + + # Construct network. 
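+    # The preconditioning wrapper selected in the training script (e.g.
+    # modulus.models.diffusion.EDMPrecond) is resolved here by class name; it
+    # receives the dataset geometry (resolution, channels, label_dim) in addition
+    # to the architecture options collected in network_kwargs.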
+ logger0.info("Constructing network...") + interface_kwargs = dict( + img_resolution=dataset_obj.resolution, + img_channels=dataset_obj.num_channels, + label_dim=dataset_obj.label_dim, + ) # cifar10 + # interface_kwargs = dict(img_resolution=yparams.crop_size_x, img_channels=img_out_channels, img_in_channels=img_in_channels, img_out_channels=img_out_channels, label_dim=0) #weather + + net = construct_class_by_name( + **network_kwargs, **interface_kwargs + ) # subclass of torch.nn.Module + net.train().requires_grad_(True).to(device) + # net = torch.compile(net) + # Distributed data parallel + if dist.world_size > 1: + ddp = DistributedDataParallel( + net, + device_ids=[dist.local_rank], + broadcast_buffers=dist.broadcast_buffers, + output_device=dist.device, + find_unused_parameters=dist.find_unused_parameters, + ) # broadcast_buffers=True for weather data + else: + ddp = net + + if dist.rank == 0: + with torch.no_grad(): + images = torch.zeros( + [batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], + device=device, + ) + # img_clean = torch.zeros([batch_gpu, img_out_channels, net.img_resolution, net.img_resolution], device=device) + # img_lr = torch.zeros([batch_gpu, img_in_channels, net.img_resolution, net.img_resolution], device=device) + sigma = torch.ones([batch_gpu], device=device) + labels = torch.zeros([batch_gpu, net.label_dim], device=device) + # print_module_summary(net, [img_clean, img_lr, sigma, labels], max_nesting=2) + print_module_summary(net, [images, sigma, labels], max_nesting=2) + + # import pdb; pdb.set_trace() + # breakpoint() + + # params = net.parameters() + # print('************************************') + # print('dist.get_rank()', dist.get_rank()) + # print('net.parameters()', net.parameters()) + # for idx, param in enumerate(net.parameters()): + # if idx == 230: + # print(f"Parameter {idx}: {param.stride()}") + # print(f"Parameter {idx}: {param.shape}") + # break + # print('************************************') + + # Setup optimizer. + logger0.info("Setting up optimizer...") + loss_fn = construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss + optimizer = construct_class_by_name( + params=net.parameters(), **optimizer_kwargs + ) # subclass of torch.optim.Optimizer + augment_pipe = ( + construct_class_by_name(**augment_kwargs) + if augment_kwargs is not None + else None + ) # training.augment.AugmentPipe + ema = copy.deepcopy(net).eval().requires_grad_(False) + + # # Import autoresume module + # #print('os.environ', print(os.environ)) + # # sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.')) + # SUBMIT_SCRIPTS = '/lustre/fsw/adlr/adlr-others/gpeled/adlr-utils/release/cluster-interface/latest' + # sys.path.append(SUBMIT_SCRIPTS) + # #sync autoresums across gpus ... + # AutoResume = None + # try: + # from userlib.auto_resume import AutoResume + # AutoResume.init() + # except ImportError: + # print('AutoResume not imported') + + # Resume training from previous snapshot. 
+ if resume_pkl is not None: + logger0.info(f'Loading network weights from "{resume_pkl}"...') + if dist.rank != 0: + torch.distributed.barrier() # rank 0 goes first + with open_url(resume_pkl, verbose=(dist.rank == 0)) as f: + data = pickle.load(f) + if dist.rank == 0: + torch.distributed.barrier() # other ranks follow + copy_params_and_buffers( + src_module=data["ema"], dst_module=net, require_all=False + ) + copy_params_and_buffers( + src_module=data["ema"], dst_module=ema, require_all=False + ) + del data # conserve memory + if resume_state_dump: + logger0.info(f'Loading training state from "{resume_state_dump}"...') + data = torch.load(resume_state_dump, map_location=torch.device("cpu")) + copy_params_and_buffers( + src_module=data["net"], dst_module=net, require_all=True + ) + optimizer.load_state_dict(data["optimizer_state"]) + del data # conserve memory + + # #check num params per gpu + # with open(f"params_{dist.get_rank()}.txt", "w") as fo: + # logger0.info(net.parameters()) + # for param in net.parameters(): + # logger0.info(param.size()) + # #fo.write(f"{name}\t{param.size()}\n") + # import pdb; pdb.set_trace() + + # Train. + logger0.info(f"Training for {total_kimg} kimg...") + cur_nimg = resume_kimg * 1000 + cur_tick = 0 + tick_start_nimg = cur_nimg + tick_start_time = time.time() + maintenance_time = tick_start_time - start_time + # dist.update_progress(cur_nimg // 1000, total_kimg) # TODO check if needed + stats_jsonl = None + while True: + + # Accumulate gradients. + optimizer.zero_grad() + for round_idx in range(num_accumulation_rounds): + with ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)): + + # # Fetch training data: weather + # img_clean, img_lr, labels = next(dataset_iterator) + + # logger0.info(img_clean.shape) + # logger0.info('max-clean', torch.max(img_clean)) + # logger0.info('min-clean', torch.min(img_clean)) + # logger0.info('mean-clean', torch.mean(img_clean)) + # logger0.info('std-clean', torch.std(img_clean)) + # logger0.info(img_lr.shape) + # logger0.info('max-lr', torch.max(img_lr)) + # logger0.info('min-lr', torch.min(img_lr)) + # logger0.info('mean-lr', torch.mean(img_lr)) + # logger0.info('std-lr', torch.std(img_lr)) + # import pdb; pdb.set_trace() + + # # Normalization: weather (normalized already in the dataset) + # img_clean = img_clean.to(device).to(torch.float32).contiguous() #[-4.5, +4.5] + # img_lr = img_lr.to(device).to(torch.float32).contiguous() + # labels = labels.to(device).contiguous() + + # Fetch training data: cifar10 + images, labels = next(dataset_iterator) + # Normalization: cifar10 (normalized already in the dataset) + images = images.to(device).to(torch.float32) / 127.5 - 1 + labels = labels.to(device) + + # loss = loss_fn(net=ddp, img_clean=img_clean, img_lr=img_lr, labels=labels, augment_pipe=augment_pipe) + loss = loss_fn( + net=ddp, images=images, labels=labels, augment_pipe=augment_pipe + ) + report("Loss/loss", loss) + loss.sum().mul(loss_scaling / batch_gpu_total).backward() + + # Update weights. + for g in optimizer.param_groups: + g["lr"] = optimizer_kwargs["lr"] * min( + cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1 + ) + for param in net.parameters(): + if param.grad is not None: + torch.nan_to_num( + param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad + ) + optimizer.step() + + # Update EMA. 
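+            # Per-step EMA decay: ema_beta = 0.5 ** (batch_size / ema_halflife_nimg),
+            # i.e. the EMA loses half of its weight on old parameters once
+            # ema_halflife_kimg thousand images have been processed; ema_rampup_ratio
+            # shortens the half-life early in training.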
+ ema_halflife_nimg = ema_halflife_kimg * 1000 + if ema_rampup_ratio is not None: + ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio) + ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8)) + for p_ema, p_net in zip(ema.parameters(), net.parameters()): + p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta)) + + # Perform maintenance tasks once per tick. + cur_nimg += batch_size + done = cur_nimg >= total_kimg * 1000 + if ( + (not done) + and (cur_tick != 0) + and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000) + ): + continue + + # Print status line, accumulating the same information in training_stats. + tick_end_time = time.time() + fields = [] + fields += [f"tick {report0('Progress/tick', cur_tick):<5d}"] + fields += [f"kimg {report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"] + fields += [ + f"time {format_time(report0('Timing/total_sec', tick_end_time - start_time)):<12s}" + ] + fields += [ + f"sec/tick {report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}" + ] + fields += [ + f"sec/kimg {report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}" + ] + fields += [ + f"maintenance {report0('Timing/maintenance_sec', maintenance_time):<6.1f}" + ] + fields += [ + f"cpumem {report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" + ] + fields += [ + f"gpumem {report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}" + ] + fields += [ + f"reserved {report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}" + ] + torch.cuda.reset_peak_memory_stats() + logger0.info(" ".join(fields)) + + # ckpt_dir = run_dir + + # print('AutoResume.termination_requested()', AutoResume.termination_requested()) + # print('AutoResume', AutoResume) + + # if AutoResume.termination_requested(): + # AutoResume.request_resume() + # print("Training terminated. Returning") + # done = True + # #print('dist.get_rank()', dist.get_rank()) + # #with open(os.path.join(os.path.split(ckpt_dir)[0],'resume.txt'), "w") as f: + # with open(os.path.join(ckpt_dir,'resume.txt'), "w") as f: + # f.write(os.path.join(ckpt_dir, f'training-state-{cur_nimg//1000:06d}.pt')) + # print(os.path.join(ckpt_dir, f'training-state-{cur_nimg//1000:06d}.pt')) + # f.close() + # #return 0 + + # dist.print0('*********************************************') + # dist.print0('dist.should_stop()', dist.should_stop()) + # dist.print0('done', done) + # dist.print0('*********************************************') + + # Check for abort. # TODO: check if needed! + # if (not done) and dist.should_stop(): + # done = True + # logger0.info() + # logger0.info("Aborting...") + + # Save network snapshot. + if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0): + data = dict( + ema=ema, + loss_fn=loss_fn, + augment_pipe=augment_pipe, + dataset_kwargs=dict(dataset_kwargs), + ) + for key, value in data.items(): + if isinstance(value, torch.nn.Module): + value = copy.deepcopy(value).eval().requires_grad_(False) + if dist.world_size > 1: + check_ddp_consistency(value) + data[key] = value.cpu() + del value # conserve memory + if dist.rank == 0: + with open( + os.path.join(run_dir, f"network-snapshot-{cur_nimg//1000:06d}.pkl"), + "wb", + ) as f: + pickle.dump(data, f) + del data # conserve memory + + # Save full dump of the training state. 
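+        # Unlike the .pkl snapshot above (a picklable copy of the EMA model for
+        # sampling), the .pt state dump below also stores the raw network and the
+        # optimizer state so training can be resumed exactly; only rank 0 writes it.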
+ if ( + (state_dump_ticks is not None) + and (done or cur_tick % state_dump_ticks == 0) + and cur_tick != 0 + and dist.rank == 0 + ): + # if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and dist.get_rank() == 0: + torch.save( + dict(net=net, optimizer_state=optimizer.state_dict()), + os.path.join(run_dir, f"training-state-{cur_nimg//1000:06d}.pt"), + ) + + # Update logs. + default_collector.update() + if dist.rank == 0: + if stats_jsonl is None: + stats_jsonl = open(os.path.join(run_dir, "stats.jsonl"), "at") + stats_jsonl.write( + json.dumps( + dict( + default_collector.as_dict(), + timestamp=time.time(), + ) + ) + + "\n" + ) + stats_jsonl.flush() + # dist.update_progress(cur_nimg // 1000, total_kimg) # TODO check if needed + + # Update state. + cur_tick += 1 + tick_start_nimg = cur_nimg + tick_start_time = time.time() + maintenance_time = tick_start_time - tick_end_time + if done: + break + + # Done. + logger0.info() + logger0.info("Exiting...") diff --git a/examples/generative/diffusion/training_stats.py b/examples/generative/diffusion/training_stats.py new file mode 100644 index 0000000000..c9a8a5247d --- /dev/null +++ b/examples/generative/diffusion/training_stats.py @@ -0,0 +1,301 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Facilities for reporting and collecting training statistics across +multiple processes and devices. The interface is designed to minimize +synchronization overhead as well as the amount of boilerplate in user +code.""" + +import re + +import numpy as np +import torch +from utils import EasyDict, profiled_function + +# ---------------------------------------------------------------------------- + +_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] +_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. +_counter_dtype = torch.float64 # Data type to use for the internal counters. +_rank = 0 # Rank of the current process. +_sync_device = ( + None # Device to use for multiprocess communication. None = single-process. +) +_sync_called = False # Has _sync() been called yet? +_counters = ( + dict() +) # Running counters on each device, updated by report(): name => device => torch.Tensor +_cumulative = ( + dict() +) # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor + +# ---------------------------------------------------------------------------- + + +def init_multiprocessing(rank, sync_device): + r"""Initializes `torch_utils.training_stats` for collecting statistics + across multiple processes. + + This function must be called after + `torch.distributed.init_process_group()` and before `Collector.update()`. + The call is not necessary if multi-process collection is not needed. + + Args: + rank: Rank of the current process. + sync_device: PyTorch device to use for inter-process + communication, or None to disable multi-process + collection. Typically `torch.device('cuda', rank)`. 
+ """ + global _rank, _sync_device + assert not _sync_called + _rank = rank + _sync_device = sync_device + + +# ---------------------------------------------------------------------------- + + +@profiled_function +def report(name, value): + r"""Broadcasts the given set of scalars to all interested instances of + `Collector`, across device and process boundaries. + + This function is expected to be extremely cheap and can be safely + called from anywhere in the training loop, loss function, or inside a + `torch.nn.Module`. + + Warning: The current implementation expects the set of unique names to + be consistent across processes. Please make sure that `report()` is + called at least once for each unique name by each process, and in the + same order. If a given process has no scalars to broadcast, it can do + `report(name, [])` (empty list). + + Args: + name: Arbitrary string specifying the name of the statistic. + Averages are accumulated separately for each unique name. + value: Arbitrary set of scalars. Can be a list, tuple, + NumPy array, PyTorch tensor, or Python scalar. + + Returns: + The same `value` that was passed in. + """ + if name not in _counters: + _counters[name] = dict() + + elems = torch.as_tensor(value) + if elems.numel() == 0: + return value + + elems = elems.detach().flatten().to(_reduce_dtype) + moments = torch.stack( + [ + torch.ones_like(elems).sum(), + elems.sum(), + elems.square().sum(), + ] + ) + assert moments.ndim == 1 and moments.shape[0] == _num_moments + moments = moments.to(_counter_dtype) + + device = moments.device + if device not in _counters[name]: + _counters[name][device] = torch.zeros_like(moments) + _counters[name][device].add_(moments) + return value + + +# ---------------------------------------------------------------------------- + + +def report0(name, value): + r"""Broadcasts the given set of scalars by the first process (`rank = 0`), + but ignores any scalars provided by the other processes. + See `report()` for further details. + """ + report(name, value if _rank == 0 else []) + return value + + +# ---------------------------------------------------------------------------- + + +class Collector: + r"""Collects the scalars broadcasted by `report()` and `report0()` and + computes their long-term averages (mean and standard deviation) over + user-defined periods of time. + + The averages are first collected into internal counters that are not + directly visible to the user. They are then copied to the user-visible + state as a result of calling `update()` and can then be queried using + `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the + internal counters for the next round, so that the user-visible state + effectively reflects averages collected between the last two calls to + `update()`. + + Args: + regex: Regular expression defining which statistics to + collect. The default is to collect everything. + keep_previous: Whether to retain the previous averages if no + scalars were collected on a given round + (default: True). + """ + + def __init__(self, regex=".*", keep_previous=True): + self._regex = re.compile(regex) + self._keep_previous = keep_previous + self._cumulative = dict() + self._moments = dict() + self.update() + self._moments.clear() + + def names(self): + r"""Returns the names of all statistics broadcasted so far that + match the regular expression specified at construction time. 
+ """ + return [name for name in _counters if self._regex.fullmatch(name)] + + def update(self): + r"""Copies current values of the internal counters to the + user-visible state and resets them for the next round. + + If `keep_previous=True` was specified at construction time, the + operation is skipped for statistics that have received no scalars + since the last update, retaining their previous averages. + + This method performs a number of GPU-to-CPU transfers and one + `torch.distributed.all_reduce()`. It is intended to be called + periodically in the main training loop, typically once every + N training steps. + """ + if not self._keep_previous: + self._moments.clear() + for name, cumulative in _sync(self.names()): + if name not in self._cumulative: + self._cumulative[name] = torch.zeros( + [_num_moments], dtype=_counter_dtype + ) + delta = cumulative - self._cumulative[name] + self._cumulative[name].copy_(cumulative) + if float(delta[0]) != 0: + self._moments[name] = delta + + def _get_delta(self, name): + r"""Returns the raw moments that were accumulated for the given + statistic between the last two calls to `update()`, or zero if + no scalars were collected. + """ + assert self._regex.fullmatch(name) + if name not in self._moments: + self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + return self._moments[name] + + def num(self, name): + r"""Returns the number of scalars that were accumulated for the given + statistic between the last two calls to `update()`, or zero if + no scalars were collected. + """ + delta = self._get_delta(name) + return int(delta[0]) + + def mean(self, name): + r"""Returns the mean of the scalars that were accumulated for the + given statistic between the last two calls to `update()`, or NaN if + no scalars were collected. + """ + delta = self._get_delta(name) + if int(delta[0]) == 0: + return float("nan") + return float(delta[1] / delta[0]) + + def std(self, name): + r"""Returns the standard deviation of the scalars that were + accumulated for the given statistic between the last two calls to + `update()`, or NaN if no scalars were collected. + """ + delta = self._get_delta(name) + if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): + return float("nan") + if int(delta[0]) == 1: + return float(0) + mean = float(delta[1] / delta[0]) + raw_var = float(delta[2] / delta[0]) + return np.sqrt(max(raw_var - np.square(mean), 0)) + + def as_dict(self): + r"""Returns the averages accumulated between the last two calls to + `update()` as an `EasyDict`. The contents are as follows: + + EasyDict( + NAME = EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), + ... + ) + """ + stats = EasyDict() + for name in self.names(): + stats[name] = EasyDict( + num=self.num(name), mean=self.mean(name), std=self.std(name) + ) + return stats + + def __getitem__(self, name): + r"""Convenience getter. + `collector[name]` is a synonym for `collector.mean(name)`. + """ + return self.mean(name) + + +# ---------------------------------------------------------------------------- + + +def _sync(names): + r"""Synchronize the global cumulative counters across devices and + processes. Called internally by `Collector.update()`. + """ + if len(names) == 0: + return [] + global _sync_called + _sync_called = True + + # Collect deltas within current rank. 
+ deltas = [] + device = _sync_device if _sync_device is not None else torch.device("cpu") + for name in names: + delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) + for counter in _counters[name].values(): + delta.add_(counter.to(device)) + counter.copy_(torch.zeros_like(counter)) + deltas.append(delta) + deltas = torch.stack(deltas) + + # Sum deltas across ranks. + if _sync_device is not None: + torch.distributed.all_reduce(deltas) + + # Update cumulative values. + deltas = deltas.cpu() + for idx, name in enumerate(names): + if name not in _cumulative: + _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + _cumulative[name].add_(deltas[idx]) + + # Return name-value pairs. + return [(name, _cumulative[name]) for name in names] + + +# ---------------------------------------------------------------------------- +# Convenience. + +default_collector = Collector() + +# ---------------------------------------------------------------------------- diff --git a/examples/generative/diffusion/utils.py b/examples/generative/diffusion/utils.py new file mode 100644 index 0000000000..2d0c00d502 --- /dev/null +++ b/examples/generative/diffusion/utils.py @@ -0,0 +1,864 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Miscellaneous utility classes and functions.""" + + +import contextlib +import ctypes +import fnmatch +import glob +import hashlib +import html +import importlib +import inspect +import io +import os +import pickle +import re +import shutil +import sys +import tempfile +import types +import urllib +import urllib.request +import uuid +import warnings +from typing import Any, List, Tuple, Union + +import numpy as np +import requests +import torch + + +class EasyDict(dict): + """ + Convenience class that behaves like a dict but allows access with the attribute + syntax. + """ + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class StackedRandomGenerator: + """ + Wrapper for torch.Generator that allows specifying a different random seed + for each sample in a minibatch. 
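+
+    A minimal usage sketch (illustrative, not taken from this diff):
+
+        rnd = StackedRandomGenerator(device, seeds=[0, 1, 2, 3])
+        latents = rnd.randn([4, 3, 32, 32], device=device)  # one generator per sample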
+ """ + + def __init__(self, device, seeds): + super().__init__() + self.generators = [ + torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds + ] + + def randn(self, size, **kwargs): + assert size[0] == len(self.generators) + return torch.stack( + [torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators] + ) + + def randn_like(self, input): + return self.randn( + input.shape, dtype=input.dtype, layout=input.layout, device=input.device + ) + + def randint(self, *args, size, **kwargs): + assert size[0] == len(self.generators) + return torch.stack( + [ + torch.randint(*args, size=size[1:], generator=gen, **kwargs) + for gen in self.generators + ] + ) + + +def parse_int_list(s): + """ + Parse a comma separated list of numbers or ranges and return a list of ints. + Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] + """ + if isinstance(s, list): + return s + ranges = [] + range_re = re.compile(r"^(\d+)-(\d+)$") + for p in s.split(","): + m = range_re.match(p) + if m: + ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1)) + else: + ranges.append(int(p)) + return ranges + + +# Cache directories +# ------------------------------------------------------------------------------------- + +_dnnlib_cache_dir = None + + +def set_cache_dir(path: str) -> None: + global _dnnlib_cache_dir + _dnnlib_cache_dir = path + + +def make_cache_dir_path(*paths: str) -> str: + if _dnnlib_cache_dir is not None: + return os.path.join(_dnnlib_cache_dir, *paths) + if "DNNLIB_CACHE_DIR" in os.environ: + return os.path.join(os.environ["DNNLIB_CACHE_DIR"], *paths) + if "HOME" in os.environ: + return os.path.join(os.environ["HOME"], ".cache", "dnnlib", *paths) + if "USERPROFILE" in os.environ: + return os.path.join(os.environ["USERPROFILE"], ".cache", "dnnlib", *paths) + return os.path.join(tempfile.gettempdir(), ".cache", "dnnlib", *paths) + + +# Small util functions +# ------------------------------------------------------------------------------------- + + +def format_time(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format( + s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60 + ) + + +def format_time_brief(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) + else: + return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) + + +def tuple_product(t: Tuple) -> Any: + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double, +} + + +def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: + """ + Given a type name string (or an object having 
a __name__ attribute), return + matching Numpy and ctypes types that have the same size in bytes. + """ + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + assert type_str in _str_to_ctype.keys() + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + assert my_dtype.itemsize == ctypes.sizeof(my_ctype) + + return my_dtype, my_ctype + + +def is_pickleable(obj: Any) -> bool: # TODO remove + try: + with io.BytesIO() as stream: + pickle.dump(obj, stream) + return True + except: + return False + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------- + + +def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: + """ + Searches for the underlying module behind the name to some python object. + Returns the module and the object name (original name with module part removed). + """ + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [ + (".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1) + ] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? + for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith( + "No module named '" + module_name + "'" + ): + raise + + # maybe the requested attribute is missing? + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: + """ + Traverses the object name and returns the last (rightmost) python object. + """ + if obj_name == "": + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: + """ + Finds the python object with the given name. + """ + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: + """ + Finds the python object with the given name and calls it as a function. + """ + assert func_name is not None + func_obj = get_obj_by_name(func_name) + assert callable(func_obj) + return func_obj(*args, **kwargs) + + +def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: + """ + Finds the python class with the given name and constructs it with the given + arguments. 
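+
+    Example (illustrative, mirroring how the training scripts pass class_name):
+
+        loss_fn = construct_class_by_name(class_name="modulus.metrics.diffusion.EDMLoss")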
+ """ + return call_func_by_name(*args, func_name=class_name, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: + """ + Get the directory path of the module containing the given object name. + """ + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: + """ + Determine whether the given object is a top-level function, i.e., defined at module + scope using 'def'. + """ + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: + """ + Return the fully-qualified name of a top-level function. + """ + assert is_top_level_function(obj) + module = obj.__module__ + if module == "__main__": + module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] + return module + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + + +def list_dir_recursively_with_ignore( + dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False +) -> List[Tuple[str, str]]: + """ + List all files recursively in a given directory while ignoring given file and + directory names. Returns list of tuples containing both absolute and relative paths. + """ + assert os.path.isdir(dir_path) + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + assert len(absolute_paths) == len(relative_paths) + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: + """ + Takes in a list of tuples of (src, dst) paths and copies files. + Will create all necessary directories. + """ + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# URL helpers +# ------------------------------------------------------------------------------------------ + + +def is_url(obj: Any, allow_file_urls: bool = False) -> bool: + """ + Determine whether the given object is a valid URL string. + """ + if not isinstance(obj, str) or not "://" in obj: + return False + if allow_file_urls and obj.startswith("file://"): + return True + try: + res = requests.compat.urlparse(obj) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + except: + return False + return True + + +def open_url( + url: str, + cache_dir: str = None, + num_attempts: int = 10, + verbose: bool = True, + return_filename: bool = False, + cache: bool = True, +) -> Any: + """ + Download the given URL and return a binary-mode file object to access the data. 
+ This code handles unusual file:// patterns that + arise on Windows: + + file:///c:/foo.txt + + which would translate to a local '/c:/foo.txt' filename that's + invalid. Drop the forward slash for such pathnames. + + If you touch this code path, you should test it on both Linux and + Windows. + + Some internet resources suggest using urllib.request.url2pathname() but + but that converts forward slashes to backslashes and this causes + its own set of problems. + """ + assert num_attempts >= 1 + assert not (return_filename and (not cache)) + + # Doesn't look like an URL scheme so interpret it as a local filename. + if not re.match("^[a-z]+://", url): + return url if return_filename else open(url, "rb") + + if url.startswith("file://"): + filename = urllib.parse.urlparse(url).path + if re.match(r"^/[a-zA-Z]:", filename): + filename = filename[1:] + return filename if return_filename else open(filename, "rb") + + assert is_url(url) + + # Lookup from cache. + if cache_dir is None: + cache_dir = make_cache_dir_path("downloads") + + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + if cache: + cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) + if len(cache_files) == 1: + filename = cache_files[0] + return filename if return_filename else open(filename, "rb") + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." % url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [ + html.unescape(link) + for link in content_str.split('"') + if "export=download" in link + ] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError( + "Google Drive download quota exceeded -- please try again later" + ) + + match = re.search( + r'filename="([^"]*)"', + res.headers.get("Content-Disposition", ""), + ) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except KeyboardInterrupt: + raise + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Save to cache. + if cache: + safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) + safe_name = safe_name[: min(len(safe_name), 128)] + cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) + temp_file = os.path.join( + cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name + ) + os.makedirs(cache_dir, exist_ok=True) + with open(temp_file, "wb") as f: + f.write(url_data) + os.replace(temp_file, cache_file) # atomic + if return_filename: + return cache_file + + # Return data as file object. + assert not return_filename + return io.BytesIO(url_data) + + +# ---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. 
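+# For example (illustrative), repeated calls such as constant([1, 1], device=device)
+# for a resample filter return the same cached tensor instead of rebuilding it and
+# copying it host-to-device every time.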
+ +_constant_cache = dict() + + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device("cpu") + if memory_format is None: + memory_format = torch.contiguous_format + + key = ( + value.shape, + value.dtype, + value.tobytes(), + shape, + dtype, + device, + memory_format, + ) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + + +# ---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + + def nan_to_num( + input, nan=0.0, posinf=None, neginf=None, *, out=None + ): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp( + input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out + ) + + +# ---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +# ---------------------------------------------------------------------------- +# Context manager to temporarily suppress known warnings in torch.jit.trace(). +# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + + +@contextlib.contextmanager +def suppress_tracer_warnings(): + flt = ("ignore", None, torch.jit.TracerWarning, None, 0) + warnings.filters.insert(0, flt) + yield + warnings.filters.remove(flt) + + +# ---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). + + +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError( + f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}" + ) + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert( + torch.equal(torch.as_tensor(size), ref_size), + f"Wrong size for dimension {idx}", + ) + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert( + torch.equal(size, torch.as_tensor(ref_size)), + f"Wrong size for dimension {idx}: expected {ref_size}", + ) + elif size != ref_size: + raise AssertionError( + f"Wrong size for dimension {idx}: got {size}, expected {ref_size}" + ) + + +# ---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). 
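+# In this codebase, training_stats.report() is wrapped with @profiled_function so
+# that it shows up as a named range in torch.autograd.profiler traces; other hot
+# helpers can be decorated the same way.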
+ + +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + + decorator.__name__ = fn.__name__ + return decorator + + +# ---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__( + self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5 + ): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + + +# ---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + + +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + + +@torch.no_grad() +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = dict(named_params_and_buffers(src_module)) + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name]) + + +# ---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. + + +@contextlib.contextmanager +def ddp_sync(module, sync): + assert isinstance(module, torch.nn.Module) + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + + +# ---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + + +def check_ddp_consistency(module, ignore_regex=None): + assert isinstance(module, torch.nn.Module) + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + "." + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + if tensor.is_floating_point(): + tensor = nan_to_num(tensor) + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + assert (tensor == other).all(), fullname + + +# ---------------------------------------------------------------------------- +# Print summary table of module hierarchy. 
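+# The summary is built by registering forward pre-/post-hooks on every submodule,
+# running a single forward pass with the supplied dummy inputs, and tabulating
+# per-module parameter/buffer counts plus output shapes and dtypes.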
+ + +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. + entries = [] + nesting = [0] + + def pre_hook(_mod, _inputs): + nesting[0] += 1 + + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(EasyDict(mod=mod, outputs=outputs)) + + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= { + id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs + } + + # Filter out redundant entries. + if skip_redundant: + entries = [ + e + for e in entries + if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs) + ] + + # Construct table. + rows = [ + [type(module).__name__, "Parameters", "Buffers", "Output shape", "Datatype"] + ] + rows += [["---"] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = "" if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(t.shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split(".")[-1] for t in e.outputs] + rows += [ + [ + name + (":0" if len(e.outputs) >= 2 else ""), + str(param_size) if param_size else "-", + str(buffer_size) if buffer_size else "-", + (output_shapes + ["-"])[0], + (output_dtypes + ["-"])[0], + ] + ] + for idx in range(1, len(e.outputs)): + rows += [ + [name + f":{idx}", "-", "-", output_shapes[idx], output_dtypes[idx]] + ] + param_total += param_size + buffer_total += buffer_size + rows += [["---"] * len(rows[0])] + rows += [["Total", str(param_total), str(buffer_total), "-", "-"]] + + # Print table. + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print( + " ".join( + cell + " " * (width - len(cell)) for cell, width in zip(row, widths) + ) + ) + print() + return outputs + + +# ---------------------------------------------------------------------------- diff --git a/modulus/metrics/diffusion/__init__.py b/modulus/metrics/diffusion/__init__.py new file mode 100644 index 0000000000..50c892f66a --- /dev/null +++ b/modulus/metrics/diffusion/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .fid import calculate_fid_from_inception_stats +from .loss import EDMLoss, MixtureLoss, RegressionLoss, ResLoss, VELoss, VPLoss diff --git a/modulus/metrics/diffusion/fid.py b/modulus/metrics/diffusion/fid.py new file mode 100644 index 0000000000..f2bfb24cf6 --- /dev/null +++ b/modulus/metrics/diffusion/fid.py @@ -0,0 +1,49 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from scipy.linalg import sqrtm + + +def calculate_fid_from_inception_stats( + mu: np.ndarray, sigma: np.ndarray, mu_ref: np.ndarray, sigma_ref: np.ndarray +) -> float: + """ + Calculate the Fréchet Inception Distance (FID) between two sets + of Inception statistics. + + The Fréchet Inception Distance is a measure of the similarity between two datasets + based on their Inception features (mu and sigma). It is commonly used to evaluate + the quality of generated images in generative models. + + Parameters + ---------- + mu: np.ndarray: + Mean of Inception statistics for the generated dataset. + sigma: np.ndarray: + Covariance matrix of Inception statistics for the generated dataset. + mu_ref: np.ndarray + Mean of Inception statistics for the reference dataset. + sigma_ref: np.ndarray + Covariance matrix of Inception statistics for the reference dataset. + + Returns + ------- + float + The Fréchet Inception Distance (FID) between the two datasets. + """ + m = np.square(mu - mu_ref).sum() + s, _ = sqrtm(np.dot(sigma, sigma_ref), disp=False) + fid = m + np.trace(sigma + sigma_ref - s * 2) + return float(np.real(fid)) diff --git a/modulus/metrics/diffusion/loss.py b/modulus/metrics/diffusion/loss.py new file mode 100644 index 0000000000..f73a21b5e6 --- /dev/null +++ b/modulus/metrics/diffusion/loss.py @@ -0,0 +1,554 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""Loss functions used in the paper +"Elucidating the Design Space of Diffusion-Based Generative Models".""" + +from typing import Callable, Optional, Union + +import torch + + +class VPLoss: + """ + Loss function corresponding to the variance preserving (VP) formulation. + + Parameters + ---------- + beta_d: float, optional + Coefficient for the diffusion process, by default 19.9. + beta_min: float, optional + Minimum bound, by defaults 0.1. + epsilon_t: float, optional + Small positive value, by default 1e-5. + + Note: + ----- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + + """ + + def __init__( + self, beta_d: float = 19.9, beta_min: float = 0.1, epsilon_t: float = 1e-5 + ): + self.beta_d = beta_d + self.beta_min = beta_min + self.epsilon_t = epsilon_t + + def __call__( + self, + net: torch.nn.Module, + images: torch.Tensor, + labels: torch.Tensor, + augment_pipe: Optional[Callable] = None, + ): + """ + Calculate and return the loss corresponding to the variance preserving (VP) + formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the input images. + The noise level is determined by 'sigma', which is computed as a function of + 'epsilon_t' and random values. The calculated loss is weighted based on the + inverse of 'sigma^2'. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input images to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) + sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1)) + weight = 1 / sigma**2 + y, augment_labels = ( + augment_pipe(images) if augment_pipe is not None else (images, None) + ) + n = torch.randn_like(y) * sigma + D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + def sigma( + self, t: Union[float, torch.Tensor] + ): # NOTE: also exists in preconditioning + """ + Compute the sigma(t) value for a given t based on the VP formulation. + + The function calculates the noise level schedule for the diffusion process based + on the given parameters `beta_d` and `beta_min`. + + Parameters + ---------- + t : Union[float, torch.Tensor] + The timestep or set of timesteps for which to compute sigma(t). + + Returns + ------- + torch.Tensor + The computed sigma(t) value(s). + """ + t = torch.as_tensor(t) + return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt() + + +class VELoss: + """ + Loss function corresponding to the variance exploding (VE) formulation. + + Parameters + ---------- + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + + Note: + ----- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. 
Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + """ + + def __init__(self, sigma_min: float = 0.02, sigma_max: float = 100.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + + def __call__(self, net, images, labels, augment_pipe=None): + """ + Calculate and return the loss corresponding to the variance exploding (VE) + formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the input images. + The noise level is determined by 'sigma', which is computed as a function of + 'sigma_min' and 'sigma_max' and random values. The calculated loss is weighted + based on the inverse of 'sigma^2'. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input images to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) + sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform) + weight = 1 / sigma**2 + y, augment_labels = ( + augment_pipe(images) if augment_pipe is not None else (images, None) + ) + n = torch.randn_like(y) * sigma + D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + +# class EDMLoss: +# """ +# Loss function proposed in the EDM paper. + +# Parameters +# ---------- +# P_mean: float, optional +# Mean value for `sigma` computation, by default -1.2. +# P_std: float, optional: +# Standard deviation for `sigma` computation, by default 1.2. +# sigma_data: float, optional +# Standard deviation for data, by default 0.5. + +# Note +# ---- +# Reference: Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the +# design space of diffusion-based generative models. Advances in Neural Information +# Processing Systems, 35, pp.26565-26577. +# """ + +# def __init__( +# self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5 +# ): +# self.P_mean = P_mean +# self.P_std = P_std +# self.sigma_data = sigma_data + +# def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): +# rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) +# sigma = (rnd_normal * self.P_std + self.P_mean).exp() +# weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + +# # augment for conditional generaiton +# img_tot = torch.cat((img_clean, img_lr), dim=1) +# y_tot, augment_labels = ( +# augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) +# ) +# y = y_tot[:, : img_clean.shape[1], :, :] +# y_lr = y_tot[:, img_clean.shape[1] :, :, :] + +# n = torch.randn_like(y) * sigma +# D_yn = net(y + n, y_lr, sigma, labels, augment_labels=augment_labels) +# loss = weight * ((D_yn - y) ** 2) +# return loss + + +class EDMLoss: + """ + Loss function proposed in the EDM paper. + + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. 
+ sigma_data: float, optional + Standard deviation for data, by default 0.5. + + Note + ---- + Reference: Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + """ + + def __init__( + self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5 + ): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + + def __call__(self, net, images, labels=None, augment_pipe=None): + """ + Calculate and return the loss corresponding to the EDM formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the input images. + The noise level is determined by 'sigma', which is computed as a function of + 'P_mean' and 'P_std' random values. The calculated loss is weighted as a + function of 'sigma' and 'sigma_data'. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input images to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + y, augment_labels = ( + augment_pipe(images) if augment_pipe is not None else (images, None) + ) + n = torch.randn_like(y) * sigma + D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + +class RegressionLoss: + """ + Regression loss function for the U-Net for deterministic predictions. + + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. + sigma_data: float, optional + Standard deviation for data, by default 0.5. + + Note + ---- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5 + ): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + + def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): + """ + Calculate and return the loss for the U-Net for deterministic predictions. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + img_clean: torch.Tensor + Input images (high resolution) to the neural network. + + img_lr: torch.Tensor + Input images (low resolution) to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. 
+ + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = ( + 1.0 # (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 + ) + + img_tot = torch.cat((img_clean, img_lr), dim=1) + y_tot, augment_labels = ( + augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) + ) + y = y_tot[:, : img_clean.shape[1], :, :] + y_lr = y_tot[:, img_clean.shape[1] :, :, :] + + input = torch.zeros_like(y, device=img_clean.device) + D_yn = net(input, y_lr, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + + return loss + + +class MixtureLoss: + """ + Mixture loss function for regression and denoising score matching. + + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. + sigma_data: float, optional + Standard deviation for data, by default 0.5. + + Note + ---- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5 + ): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + + def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): + """ + Calculate and return the loss for regression and denoising score matching. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + img_clean: torch.Tensor + Input images (high resolution) to the neural network. + + img_lr: torch.Tensor + Input images (low resolution) to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) + sigma = ( + rnd_normal * self.P_std + self.P_mean + ).exp() # in the range [0,2], but signal is in [-4, 7] + den_weight = (sigma**2 + self.sigma_data**2) / ( + sigma * self.sigma_data + ) ** 2 # in the range [5,2000] with high prob. if randn in [-1,+1] + + img_tot = torch.cat((img_clean, img_lr), dim=1) + y_tot, augment_labels = ( + augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) + ) + y = y_tot[:, : img_clean.shape[1], :, :] + y_lr = y_tot[:, img_clean.shape[1] :, :, :] + + n = torch.randn_like(y) * sigma + latent = y + n + D_yn = net(latent, y_lr, sigma, labels, augment_labels=augment_labels) + R_yn = net( + latent * 0.0, y_lr, sigma, labels, augment_labels=augment_labels + ) # regression loss, zero stochasticity + + reg_weight = torch.tensor(5.0).cuda() + loss = den_weight * ((D_yn - y) ** 2) + reg_weight * ((R_yn - y) ** 2) + + return loss + + +class ResLoss: + """ + Mixture loss function for denoising score matching. 
+ + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. + sigma_data: float, optional + Standard deviation for data, by default 0.5. + + Note + ---- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, P_mean: float = 0.0, P_std: float = 1.2, sigma_data: float = 0.5 + ): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + + with torch.no_grad(): + resume_state_dump = "/training-state-042650.pt" + data = torch.load(resume_state_dump, map_location=torch.device("cpu")) + self.unet = data["net"].cuda() # TODO better handling of device + # misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True) + + def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): + """ + Calculate and return the loss for denoising score matching. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + img_clean: torch.Tensor + Input images (high resolution) to the neural network. + + img_lr: torch.Tensor + Input images (low resolution) to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + # augment for conditional generaiton + img_tot = torch.cat((img_clean, img_lr), dim=1) + y_tot, augment_labels = ( + augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) + ) + y = y_tot[:, : img_clean.shape[1], :, :] + y_lr = y_tot[:, img_clean.shape[1] :, :, :] + + # form residual + y_mean = self.unet( + torch.zeros_like(y, device=img_clean.device), + y_lr, + sigma, + labels, + augment_labels=augment_labels, + ) + y = y - y_mean + + latent = y + torch.randn_like(y) * sigma + D_yn = net(latent, y_lr, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + + return loss diff --git a/modulus/models/diffusion/__init__.py b/modulus/models/diffusion/__init__.py new file mode 100644 index 0000000000..dda7ccbff8 --- /dev/null +++ b/modulus/models/diffusion/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
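To show how the pieces fit together, here is a minimal, illustrative training step (not part of the patch) pairing `EDMLoss` from the loss module above with the `EDMPrecond` wrapper and `SongUNet` backbone introduced below. The batch shape, optimizer, and learning rate are placeholder assumptions, not the recipe defaults.

```python
import torch

from modulus.metrics.diffusion import EDMLoss
from modulus.models.diffusion import EDMPrecond

# EDM preconditioning wrapped around a SongUNet backbone (both added in this patch).
net = EDMPrecond(img_resolution=16, img_channels=2, model_type="SongUNet")
loss_fn = EDMLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

images = torch.randn(2, 2, 16, 16)  # stand-in for a real training batch
loss = loss_fn(net=net, images=images, labels=None).mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"weighted denoising loss: {loss.item():.4f}")
```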
+# ruff: noqa +from .utils import weight_init +from .layers import ( + AttentionOp, + Conv2d, + FourierEmbedding, + GroupNorm, + Linear, + PositionalEmbedding, + UNetBlock, +) +from .song_unet import SongUNet +from .dhariwal_unet import DhariwalUNet +from .unet import UNet +from .preconditioning import EDMPrecond, VEPrecond, VPPrecond, iDDPMPrecond diff --git a/modulus/models/diffusion/dhariwal_unet.py b/modulus/models/diffusion/dhariwal_unet.py new file mode 100644 index 0000000000..2f5213ed50 --- /dev/null +++ b/modulus/models/diffusion/dhariwal_unet.py @@ -0,0 +1,257 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model architectures used in the paper "Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +from dataclasses import dataclass +from typing import List + +import numpy as np +import torch +from torch.nn.functional import silu + +from modulus.models.diffusion import ( + Conv2d, + GroupNorm, + Linear, + PositionalEmbedding, + UNetBlock, +) +from modulus.models.meta import ModelMetaData +from modulus.models.module import Module + + +@dataclass +class MetaData(ModelMetaData): + name: str = "DhariwalUNet" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = True + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class DhariwalUNet(Module): + """ + Reimplementation of the ADM architecture, a U-Net variant, with optional + self-attention. + + This model supports conditional and unconditional setups, as well as several + options for various internal architectural choices such as encoder and decoder + type, embedding type, etc., making it flexible and adaptable to different tasks + and configurations. + + Parameters: + ----------- + img_resolution : int + The resolution of the input/output image. + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels in the output image. + label_dim : int, optional + Number of class labels; 0 indicates an unconditional model. By default 0. + augment_dim : int, optional + Dimensionality of augmentation labels; 0 means no augmentation. By default 0. + model_channels : int, optional + Base multiplier for the number of channels across the network, by default 192. + channel_mult : List[int], optional + Per-resolution multipliers for the number of channels. By default [1,2,3,4]. + channel_mult_emb : int, optional + Multiplier for the dimensionality of the embedding vector. By default 4. + num_blocks : int, optional + Number of residual blocks per resolution. By default 3. + attn_resolutions : List[int], optional + Resolutions at which self-attention layers are applied. By default [32, 16, 8]. + dropout : float, optional + Dropout probability applied to intermediate activations. By default 0.10. 
+ label_dropout : float, optional + Dropout probability of class labels for classifier-free guidance. By default 0.0. + + Note: + ----- + Reference: Dhariwal, P. and Nichol, A., 2021. Diffusion models beat gans on image + synthesis. Advances in neural information processing systems, 34, pp.8780-8794. + + Note: + ----- + Equivalent to the original implementation by Dhariwal and Nichol, available at + https://github.com/openai/guided-diffusion + + Example: + -------- + >>> model = DhariwalUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> noise_labels = torch.randn([1]) + >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> input_image = torch.ones([1, 2, 16, 16]) + >>> output_image = model(input_image, noise_labels, class_labels) + """ + + def __init__( + self, + img_resolution: int, + in_channels: int, + out_channels: int, + label_dim: int = 0, + augment_dim: int = 0, + model_channels: int = 192, + channel_mult: List[int] = [1, 2, 3, 4], + channel_mult_emb: int = 4, + num_blocks: int = 3, + attn_resolutions: List[int] = [32, 16, 8], + dropout: float = 0.10, + label_dropout: float = 0.0, + ): + super().__init__(meta=MetaData()) + self.label_dropout = label_dropout + emb_channels = model_channels * channel_mult_emb + init = dict( + init_mode="kaiming_uniform", + init_weight=np.sqrt(1 / 3), + init_bias=np.sqrt(1 / 3), + ) + init_zero = dict(init_mode="kaiming_uniform", init_weight=0, init_bias=0) + block_kwargs = dict( + emb_channels=emb_channels, + channels_per_head=64, + dropout=dropout, + init=init, + init_zero=init_zero, + ) + + # Mapping. + self.map_noise = PositionalEmbedding(num_channels=model_channels) + self.map_augment = ( + Linear( + in_features=augment_dim, + out_features=model_channels, + bias=False, + **init_zero, + ) + if augment_dim + else None + ) + self.map_layer0 = Linear( + in_features=model_channels, out_features=emb_channels, **init + ) + self.map_layer1 = Linear( + in_features=emb_channels, out_features=emb_channels, **init + ) + self.map_label = ( + Linear( + in_features=label_dim, + out_features=emb_channels, + bias=False, + init_mode="kaiming_normal", + init_weight=np.sqrt(label_dim), + ) + if label_dim + else None + ) + + # Encoder. + self.enc = torch.nn.ModuleDict() + cout = in_channels + for level, mult in enumerate(channel_mult): + res = img_resolution >> level + if level == 0: + cin = cout + cout = model_channels * mult + self.enc[f"{res}x{res}_conv"] = Conv2d( + in_channels=cin, out_channels=cout, kernel=3, **init + ) + else: + self.enc[f"{res}x{res}_down"] = UNetBlock( + in_channels=cout, out_channels=cout, down=True, **block_kwargs + ) + for idx in range(num_blocks): + cin = cout + cout = model_channels * mult + self.enc[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, + out_channels=cout, + attention=(res in attn_resolutions), + **block_kwargs, + ) + skips = [block.out_channels for block in self.enc.values()] + + # Decoder. 
+ self.dec = torch.nn.ModuleDict() + for level, mult in reversed(list(enumerate(channel_mult))): + res = img_resolution >> level + if level == len(channel_mult) - 1: + self.dec[f"{res}x{res}_in0"] = UNetBlock( + in_channels=cout, out_channels=cout, attention=True, **block_kwargs + ) + self.dec[f"{res}x{res}_in1"] = UNetBlock( + in_channels=cout, out_channels=cout, **block_kwargs + ) + else: + self.dec[f"{res}x{res}_up"] = UNetBlock( + in_channels=cout, out_channels=cout, up=True, **block_kwargs + ) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = model_channels * mult + self.dec[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, + out_channels=cout, + attention=(res in attn_resolutions), + **block_kwargs, + ) + self.out_norm = GroupNorm(num_channels=cout) + self.out_conv = Conv2d( + in_channels=cout, out_channels=out_channels, kernel=3, **init_zero + ) + + def forward(self, x, noise_labels, class_labels, augment_labels=None): + # Mapping. + emb = self.map_noise(noise_labels) + if self.map_augment is not None and augment_labels is not None: + emb = emb + self.map_augment(augment_labels) + emb = silu(self.map_layer0(emb)) + emb = self.map_layer1(emb) + if self.map_label is not None: + tmp = class_labels + if self.training and self.label_dropout: + tmp = tmp * ( + torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout + ).to(tmp.dtype) + emb = emb + self.map_label(tmp) + emb = silu(emb) + + # Encoder. + skips = [] + for block in self.enc.values(): + x = block(x, emb) if isinstance(block, UNetBlock) else block(x) + skips.append(x) + + # Decoder. + for block in self.dec.values(): + if x.shape[1] != block.in_channels: + x = torch.cat([x, skips.pop()], dim=1) + x = block(x, emb) + x = self.out_conv(silu(self.out_norm(x))) + return x diff --git a/modulus/models/diffusion/layers.py b/modulus/models/diffusion/layers.py new file mode 100644 index 0000000000..ce8caf0153 --- /dev/null +++ b/modulus/models/diffusion/layers.py @@ -0,0 +1,541 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model architecture layers used in the paper "Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +from typing import Any, Dict, List + +import numpy as np +import torch +from torch.nn.functional import silu + +from modulus.models.diffusion import weight_init + + +class Linear(torch.nn.Module): + """ + A fully connected (dense) layer implementation. The layer's weights and biases can + be initialized using custom initialization strategies like "kaiming_normal", + and can be further scaled by factors `init_weight` and `init_bias`. + + Parameters + ---------- + in_features : int + Size of each input sample. + out_features : int + Size of each output sample. + bias : bool, optional + The biases of the layer. If set to `None`, the layer will not learn an additive + bias. By default True. 
+    init_mode : str, optional (default="kaiming_normal")
+        The mode/type of initialization to use for weights and biases. Supported modes
+        are:
+        - "xavier_uniform": Xavier (Glorot) uniform initialization.
+        - "xavier_normal": Xavier (Glorot) normal initialization.
+        - "kaiming_uniform": Kaiming (He) uniform initialization.
+        - "kaiming_normal": Kaiming (He) normal initialization.
+        By default "kaiming_normal".
+    init_weight : float, optional
+        A scaling factor to multiply with the initialized weights. By default 1.
+    init_bias : float, optional
+        A scaling factor to multiply with the initialized biases. By default 0.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        init_mode: str = "kaiming_normal",
+        init_weight: int = 1,
+        init_bias: int = 0,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features)
+        self.weight = torch.nn.Parameter(
+            weight_init([out_features, in_features], **init_kwargs) * init_weight
+        )
+        self.bias = (
+            torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias)
+            if bias
+            else None
+        )
+
+    def forward(self, x):
+        x = x @ self.weight.to(x.dtype).t()
+        if self.bias is not None:
+            x = x.add_(self.bias.to(x.dtype))
+        return x
+
+
+class Conv2d(torch.nn.Module):
+    """
+    A custom 2D convolutional layer implementation with support for up-sampling,
+    down-sampling, and custom weight and bias initializations. The layer's weights
+    and biases can be initialized using custom initialization strategies like
+    "kaiming_normal", and can be further scaled by factors `init_weight` and
+    `init_bias`.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of channels in the input image.
+    out_channels : int
+        Number of channels produced by the convolution.
+    kernel : int
+        Size of the convolving kernel.
+    bias : bool, optional
+        The biases of the layer. If set to `None`, the layer will not learn an
+        additive bias. By default True.
+    up : bool, optional
+        Whether to perform up-sampling. By default False.
+    down : bool, optional
+        Whether to perform down-sampling. By default False.
+    resample_filter : List[int], optional
+        Filter to be used for resampling. By default [1, 1].
+    fused_resample : bool, optional
+        If True, performs fused up-sampling and convolution or fused down-sampling
+        and convolution. By default False.
+    init_mode : str, optional (default="kaiming_normal")
+        The mode/type of initialization to use for weights and biases. Supported modes
+        are:
+        - "xavier_uniform": Xavier (Glorot) uniform initialization.
+        - "xavier_normal": Xavier (Glorot) normal initialization.
+        - "kaiming_uniform": Kaiming (He) uniform initialization.
+        - "kaiming_normal": Kaiming (He) normal initialization.
+        By default "kaiming_normal".
+    init_weight : float, optional
+        A scaling factor to multiply with the initialized weights. By default 1.0.
+    init_bias : float, optional
+        A scaling factor to multiply with the initialized biases. By default 0.0.
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel: int, + bias: bool = True, + up: bool = False, + down: bool = False, + resample_filter: List[int] = [1, 1], + fused_resample: bool = False, + init_mode: str = "kaiming_normal", + init_weight: float = 1.0, + init_bias: float = 0.0, + ): + if up and down: + raise ValueError("Both 'up' and 'down' cannot be true at the same time.") + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.up = up + self.down = down + self.fused_resample = fused_resample + init_kwargs = dict( + mode=init_mode, + fan_in=in_channels * kernel * kernel, + fan_out=out_channels * kernel * kernel, + ) + self.weight = ( + torch.nn.Parameter( + weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs) + * init_weight + ) + if kernel + else None + ) + self.bias = ( + torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias) + if kernel and bias + else None + ) + f = torch.as_tensor(resample_filter, dtype=torch.float32) + f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square() + self.register_buffer("resample_filter", f if up or down else None) + + def forward(self, x): + w = self.weight.to(x.dtype) if self.weight is not None else None + b = self.bias.to(x.dtype) if self.bias is not None else None + f = ( + self.resample_filter.to(x.dtype) + if self.resample_filter is not None + else None + ) + w_pad = w.shape[-1] // 2 if w is not None else 0 + f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 + + if self.fused_resample and self.up and w is not None: + x = torch.nn.functional.conv_transpose2d( + x, + f.mul(4).tile([self.in_channels, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=max(f_pad - w_pad, 0), + ) + x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) + elif self.fused_resample and self.down and w is not None: + x = torch.nn.functional.conv2d(x, w, padding=w_pad + f_pad) + x = torch.nn.functional.conv2d( + x, + f.tile([self.out_channels, 1, 1, 1]), + groups=self.out_channels, + stride=2, + ) + else: + if self.up: + x = torch.nn.functional.conv_transpose2d( + x, + f.mul(4).tile([self.in_channels, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=f_pad, + ) + if self.down: + x = torch.nn.functional.conv2d( + x, + f.tile([self.in_channels, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=f_pad, + ) + if w is not None: + x = torch.nn.functional.conv2d(x, w, padding=w_pad) + if b is not None: + x = x.add_(b.reshape(1, -1, 1, 1)) + return x + + +class GroupNorm(torch.nn.Module): + """ + A custom Group Normalization layer implementation. + + Group Normalization (GN) divides the channels of the input tensor into groups and + normalizes the features within each group independently. It does not require the + batch size as in Batch Normalization, making itsuitable for batch sizes of any size + or even for batch-free scenarios. + + Parameters + ---------- + num_channels : int + Number of channels in the input tensor. + num_groups : int, optional + Desired number of groups to divide the input channels, by default 32. + This might be adjusted based on the `min_channels_per_group`. + min_channels_per_group : int, optional + Minimum channels required per group. This ensures that no group has fewer + channels than this number. By default 4. + eps : float, optional + A small number added to the variance to prevent division by zero, by default + 1e-5. 
+ + Notes + ----- + If `num_channels` is not divisible by `num_groups`, the actual number of groups + might be adjusted to satisfy the `min_channels_per_group` condition. + """ + + def __init__( + self, + num_channels: int, + num_groups: int = 32, + min_channels_per_group: int = 4, + eps: float = 1e-5, + ): + super().__init__() + self.num_groups = min(num_groups, num_channels // min_channels_per_group) + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(num_channels)) + self.bias = torch.nn.Parameter(torch.zeros(num_channels)) + + def forward(self, x): + x = torch.nn.functional.group_norm( + x, + num_groups=self.num_groups, + weight=self.weight.to(x.dtype), + bias=self.bias.to(x.dtype), + eps=self.eps, + ) + return x + + +class AttentionOp(torch.autograd.Function): + """ + Attention weight computation, i.e., softmax(Q^T * K). + Performs all computation using FP32, but uses the original datatype for + inputs/outputs/gradients to conserve memory. + """ + + @staticmethod + def forward(ctx, q, k): + w = ( + torch.einsum( + "ncq,nck->nqk", + q.to(torch.float32), + (k / np.sqrt(k.shape[1])).to(torch.float32), + ) + .softmax(dim=2) + .to(q.dtype) + ) + ctx.save_for_backward(q, k, w) + return w + + @staticmethod + def backward(ctx, dw): + q, k, w = ctx.saved_tensors + db = torch._softmax_backward_data( + grad_output=dw.to(torch.float32), + output=w.to(torch.float32), + dim=2, + input_dtype=torch.float32, + ) + dq = torch.einsum("nck,nqk->ncq", k.to(torch.float32), db).to( + q.dtype + ) / np.sqrt(k.shape[1]) + dk = torch.einsum("ncq,nqk->nck", q.to(torch.float32), db).to( + k.dtype + ) / np.sqrt(k.shape[1]) + return dq, dk + + +class UNetBlock(torch.nn.Module): + """ + Unified U-Net block with optional up/downsampling and self-attention. Represents + the union of all features employed by the DDPM++, NCSN++, and ADM architectures. + + Parameters: + ----------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + emb_channels : int + Number of embedding channels. + up : bool, optional + If True, applies upsampling in the forward pass. By default False. + down : bool, optional + If True, applies downsampling in the forward pass. By default False. + attention : bool, optional + If True, enables the self-attention mechanism in the block. By default False. + num_heads : int, optional + Number of attention heads. If None, defaults to `out_channels // 64`. + channels_per_head : int, optional + Number of channels per attention head. By default 64. + dropout : float, optional + Dropout probability. By default 0.0. + skip_scale : float, optional + Scale factor applied to skip connections. By default 1.0. + eps : float, optional + Epsilon value used for normalization layers. By default 1e-5. + resample_filter : List[int], optional + Filter for resampling layers. By default [1, 1]. + resample_proj : bool, optional + If True, resampling projection is enabled. By default False. + adaptive_scale : bool, optional + If True, uses adaptive scaling in the forward pass. By default True. + init : dict, optional + Initialization parameters for convolutional and linear layers. + init_zero : dict, optional + Initialization parameters with zero weights for certain layers. By default + {'init_weight': 0}. + init_attn : dict, optional + Initialization parameters specific to attention mechanism layers. + Defaults to 'init' if not provided. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + emb_channels: int, + up: bool = False, + down: bool = False, + attention: bool = False, + num_heads: int = None, + channels_per_head: int = 64, + dropout: float = 0.0, + skip_scale: float = 1.0, + eps: float = 1e-5, + resample_filter: List[int] = [1, 1], + resample_proj: bool = False, + adaptive_scale: bool = True, + init: Dict[str, Any] = dict(), + init_zero: Dict[str, Any] = dict(init_weight=0), + init_attn: Any = None, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.emb_channels = emb_channels + self.num_heads = ( + 0 + if not attention + else num_heads + if num_heads is not None + else out_channels // channels_per_head + ) + self.dropout = dropout + self.skip_scale = skip_scale + self.adaptive_scale = adaptive_scale + + self.norm0 = GroupNorm(num_channels=in_channels, eps=eps) + self.conv0 = Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel=3, + up=up, + down=down, + resample_filter=resample_filter, + **init, + ) + self.affine = Linear( + in_features=emb_channels, + out_features=out_channels * (2 if adaptive_scale else 1), + **init, + ) + self.norm1 = GroupNorm(num_channels=out_channels, eps=eps) + self.conv1 = Conv2d( + in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero + ) + + self.skip = None + if out_channels != in_channels or up or down: + kernel = 1 if resample_proj or out_channels != in_channels else 0 + self.skip = Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel=kernel, + up=up, + down=down, + resample_filter=resample_filter, + **init, + ) + + if self.num_heads: + self.norm2 = GroupNorm(num_channels=out_channels, eps=eps) + self.qkv = Conv2d( + in_channels=out_channels, + out_channels=out_channels * 3, + kernel=1, + **(init_attn if init_attn is not None else init), + ) + self.proj = Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel=1, + **init_zero, + ) + + def forward(self, x, emb): + orig = x + x = self.conv0(silu(self.norm0(x))) + + params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype) + if self.adaptive_scale: + scale, shift = params.chunk(chunks=2, dim=1) + x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) + else: + x = silu(self.norm1(x.add_(params))) + + x = self.conv1( + torch.nn.functional.dropout(x, p=self.dropout, training=self.training) + ) + x = x.add_(self.skip(orig) if self.skip is not None else orig) + x = x * self.skip_scale + + if self.num_heads: + q, k, v = ( + self.qkv(self.norm2(x)) + .reshape( + x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1 + ) + .unbind(2) + ) + w = AttentionOp.apply(q, k) + a = torch.einsum("nqk,nck->ncq", w, v) + x = self.proj(a.reshape(*x.shape)).add_(x) + x = x * self.skip_scale + return x + + +class PositionalEmbedding(torch.nn.Module): + """ + A module for generating positional embeddings based on timesteps. + This embedding technique is employed in the DDPM++ and ADM architectures. + + Parameters: + ----------- + num_channels : int + Number of channels for the embedding. + max_positions : int, optional + Maximum number of positions for the embeddings, by default 10000. + endpoint : bool, optional + If True, the embedding considers the endpoint. By default False. 
+ + """ + + def __init__( + self, num_channels: int, max_positions: int = 10000, endpoint: bool = False + ): + super().__init__() + self.num_channels = num_channels + self.max_positions = max_positions + self.endpoint = endpoint + + def forward(self, x): + freqs = torch.arange( + start=0, end=self.num_channels // 2, dtype=torch.float32, device=x.device + ) + freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) + freqs = (1 / self.max_positions) ** freqs + x = x.ger(freqs.to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x + + +class FourierEmbedding(torch.nn.Module): + """ + Generates Fourier embeddings for timesteps, primarily used in the NCSN++ + architecture. + + This class generates embeddings by first multiplying input tensor `x` and + internally stored random frequencies, and then concatenating the cosine and sine of + the resultant. + + Parameters: + ----------- + num_channels : int + The number of channels in the embedding. The final embedding size will be + 2 * num_channels because of concatenation of cosine and sine results. + scale : int, optional + A scale factor applied to the random frequencies, controlling their range + and thereby the frequency of oscillations in the embedding space. By default 16. + """ + + def __init__(self, num_channels: int, scale: int = 16): + super().__init__() + self.register_buffer("freqs", torch.randn(num_channels // 2) * scale) + + def forward(self, x): + x = x.ger((2 * np.pi * self.freqs).to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x diff --git a/modulus/models/diffusion/preconditioning.py b/modulus/models/diffusion/preconditioning.py new file mode 100644 index 0000000000..001e33ddfe --- /dev/null +++ b/modulus/models/diffusion/preconditioning.py @@ -0,0 +1,692 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Preconditioning schemes used in the paper"Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +from typing import List, Union + +import numpy as np +import torch + +from modulus.models.diffusion import DhariwalUNet, SongUNet # noqa: F401 for globals + + +class VPPrecond(torch.nn.Module): + """ + Preconditioning corresponding to the variance preserving (VP) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + beta_d : float + Extent of the noise level schedule, by default 19.9. + beta_min : float + Initial slope of the noise level schedule, by default 0.1. + M : int + Original number of timesteps in the DDPM formulation, by default 1000. + epsilon_t : float + Minimum t-value used during training, by default 1e-5. + model_type :str + Class name of the underlying model, by default "SongUNet". 
+ **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + beta_d: float = 19.9, + beta_min: float = 0.1, + M: int = 1000, + epsilon_t: float = 1e-5, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.beta_d = beta_d + self.beta_min = beta_min + self.M = M + self.epsilon_t = epsilon_t + self.sigma_min = float(self.sigma(epsilon_t)) + self.sigma_max = float(self.sigma(1)) + self.model = globals()[model_type]( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma**2 + 1).sqrt() + c_noise = (self.M - 1) * self.sigma_inv(sigma) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def sigma(self, t: Union[float, torch.Tensor]): + """ + Compute the sigma(t) value for a given t based on the VP formulation. + + The function calculates the noise level schedule for the diffusion process based + on the given parameters `beta_d` and `beta_min`. + + Parameters + ---------- + t : Union[float, torch.Tensor] + The timestep or set of timesteps for which to compute sigma(t). + + Returns + ------- + torch.Tensor + The computed sigma(t) value(s). + """ + t = torch.as_tensor(t) + return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt() + + def sigma_inv(self, sigma: Union[float, torch.Tensor]): + """ + Compute the inverse of the sigma function for a given sigma. + + This function effectively calculates t from a given sigma(t) based on the + parameters `beta_d` and `beta_min`. + + Parameters + ---------- + sigma : Union[float, torch.Tensor] + The sigma(t) value or set of sigma(t) values for which to compute the + inverse. + + Returns + ------- + torch.Tensor + The computed t value(s) corresponding to the provided sigma(t). + """ + sigma = torch.as_tensor(sigma) + return ( + (self.beta_min**2 + 2 * self.beta_d * (1 + sigma**2).log()).sqrt() + - self.beta_min + ) / self.beta_d + + def round_sigma(self, sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. 
+ + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +class VEPrecond(torch.nn.Module): + """ + Preconditioning corresponding to the variance exploding (VE) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.model = globals()[model_type]( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = sigma + c_in = 1 + c_noise = (0.5 * sigma).log() + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def round_sigma(self, sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +class iDDPMPrecond(torch.nn.Module): + """ + Preconditioning corresponding to the improved DDPM (iDDPM) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + C_1 : float + Timestep adjustment at low noise levels., by default 0.001. + C_2 : float + Timestep adjustment at high noise levels., by default 0.008. 
+ M: int + Original number of timesteps in the DDPM formulation, by default 1000. + model_type :str + Class name of the underlying model, by default "DhariwalUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Nichol, A.Q. and Dhariwal, P., 2021, July. Improved denoising diffusion + probabilistic models. In International Conference on Machine Learning + (pp. 8162-8171). PMLR. + """ + + def __init__( + self, + img_resolution, + img_channels, + label_dim=0, + use_fp16=False, + C_1=0.001, + C_2=0.008, + M=1000, + model_type="DhariwalUNet", + **model_kwargs, + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.C_1 = C_1 + self.C_2 = C_2 + self.M = M + self.model = globals()[model_type]( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels * 2, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + u = torch.zeros(M + 1) + for j in range(M, 0, -1): # M, ..., 1 + u[j - 1] = ( + (u[j] ** 2 + 1) + / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) + - 1 + ).sqrt() + self.register_buffer("u", u) + self.sigma_min = float(u[M - 1]) + self.sigma_max = float(u[0]) + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma**2 + 1).sqrt() + c_noise = ( + self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32) + ) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x[:, : self.img_channels].to(torch.float32) + return D_x + + def alpha_bar(self, j): + """ + Compute the alpha_bar(j) value for a given j based on the iDDPM formulation. + + Parameters + ---------- + j : Union[int, torch.Tensor] + The timestep or set of timesteps for which to compute alpha_bar(j). + + Returns + ------- + torch.Tensor + The computed alpha_bar(j) value(s). + """ + j = torch.as_tensor(j) + return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2 + + def round_sigma(self, sigma, return_index=False): + """ + Round the provided sigma value(s) to the nearest value(s) in a + pre-defined set `u`. + + Parameters + ---------- + sigma : Union[float, list, torch.Tensor] + The sigma value(s) to round. + return_index : bool, optional + Whether to return the index/indices of the rounded value(s) in `u` instead + of the rounded value(s) themselves, by default False. + + Returns + ------- + torch.Tensor + The rounded sigma value(s) or their index/indices in `u`, depending on the + value of `return_index`. 
+ """ + sigma = torch.as_tensor(sigma) + index = torch.cdist( + sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), + self.u.reshape(1, -1, 1), + ).argmin(2) + result = index if return_index else self.u[index.flatten()].to(sigma.dtype) + return result.reshape(sigma.shape).to(sigma.device) + + +class EDMPrecond(torch.nn.Module): + """ + Improved preconditioning proposed in the paper "Elucidating the Design Space of + Diffusion-Based Generative Models" (EDM) + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.0. + sigma_max : float + Maximum supported noise level, by default inf. + sigma_data : float + Expected standard deviation of the training data, by default 0.5. + model_type :str + Class name of the underlying model, by default "DhariwalUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + """ + + def __init__( + self, + img_resolution, + img_channels, + label_dim=0, + use_fp16=False, + sigma_min=0.0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="DhariwalUNet", + **model_kwargs, + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + + self.model = globals()[model_type]( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + @staticmethod + def round_sigma(sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). 
+ """ + return torch.as_tensor(sigma) + + +class EDMPrecondSR(torch.nn.Module): + """ + Improved preconditioning proposed in the paper "Elucidating the Design Space of + Diffusion-Based Generative Models" (EDM) for super-resolution tasks + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + img_in_channels : int + Number of input color channels. + img_out_channels : int + Number of output color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.0. + sigma_max : float + Maximum supported noise level, by default inf. + sigma_data : float + Expected standard deviation of the training data, by default 0.5. + model_type :str + Class name of the underlying model, by default "DhariwalUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + References: + - Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + - Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + img_resolution, + img_in_channels, + img_out_channels, + label_dim=0, + use_fp16=False, + sigma_min=0.0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="DhariwalUNet", + **model_kwargs, + ): + super().__init__() + self.img_resolution = img_resolution + + self.img_in_channels = img_in_channels + self.img_out_channels = img_out_channels + + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + + self.model = globals()[model_type]( + img_resolution=img_resolution, + in_channels=img_in_channels + img_out_channels, + out_channels=img_out_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward( + self, x, img_lr, sigma, class_labels=None, force_fp32=False, **model_kwargs + ): + + # Concatenate input channels + x = torch.cat((x, img_lr), dim=1) + + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." 
+ ) + + # Skip connection - for SR there's size mismatch bwtween input and output + x = x[:, 0 : self.img_out_channels, :, :] + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + @staticmethod + def round_sigma(sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + See EDMPrecond.round_sigma + """ + return EDMPrecond.round_sigma(sigma) diff --git a/modulus/models/diffusion/song_unet.py b/modulus/models/diffusion/song_unet.py new file mode 100644 index 0000000000..c31d7cf132 --- /dev/null +++ b/modulus/models/diffusion/song_unet.py @@ -0,0 +1,360 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model architectures used in the paper "Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +from dataclasses import dataclass +from typing import List + +import numpy as np +import torch +from torch.nn.functional import silu + +from modulus.models.diffusion import ( + Conv2d, + FourierEmbedding, + GroupNorm, + Linear, + PositionalEmbedding, + UNetBlock, +) +from modulus.models.meta import ModelMetaData +from modulus.models.module import Module + + +@dataclass +class MetaData(ModelMetaData): + name: str = "SongUNet" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = True + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class SongUNet(Module): + """ + Reimplementation of the DDPM++ and NCSN++ architectures, U-Net variants with + optional self-attention,embeddings, and encoder-decoder components. + + This model supports conditional and unconditional setups, as well as several + options for various internal architectural choices such as encoder and decoder + type, embedding type, etc., making it flexible and adaptable to different tasks + and configurations. + + Parameters: + ----------- + img_resolution : int + The resolution of the input/output image. + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels in the output image. + label_dim : int, optional + Number of class labels; 0 indicates an unconditional model. By default 0. + augment_dim : int, optional + Dimensionality of augmentation labels; 0 means no augmentation. By default 0. + model_channels : int, optional + Base multiplier for the number of channels across the network, by default 128. + channel_mult : List[int], optional + Per-resolution multipliers for the number of channels. By default [1,2,2,2]. + channel_mult_emb : int, optional + Multiplier for the dimensionality of the embedding vector. By default 4. + num_blocks : int, optional + Number of residual blocks per resolution. By default 4. + attn_resolutions : List[int], optional + Resolutions at which self-attention layers are applied. By default [16]. 
+ dropout : float, optional + Dropout probability applied to intermediate activations. By default 0.10. + label_dropout : float, optional + Dropout probability of class labels for classifier-free guidance. By default 0.0. + embedding_type : str, optional + Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. + By default 'positional'. + channel_mult_noise : int, optional + Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. + encoder_type : str, optional + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default + 'standard'. + decoder_type : str, optional + Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default + 'standard'. + resample_filter : List[int], optional (default=[1,1]) + Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. + + + Note: + ----- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + + Note: + ----- + Equivalent to the original implementation by Song et al., available at + https://github.com/yang-song/score_sde_pytorch + + Example: + -------- + >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> noise_labels = torch.randn([1]) + >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> input_image = torch.ones([1, 2, 16, 16]) + >>> output_image = model(input_image, noise_labels, class_labels) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + """ + + def __init__( + self, + img_resolution: int, + in_channels: int, + out_channels: int, + label_dim: int = 0, + augment_dim: int = 0, + model_channels: int = 128, + channel_mult: List[int] = [1, 2, 2, 2], + channel_mult_emb: int = 4, + num_blocks: int = 4, + attn_resolutions: List[int] = [16], + dropout: float = 0.10, + label_dropout: float = 0.0, + embedding_type: str = "positional", + channel_mult_noise: int = 1, + encoder_type: str = "standard", + decoder_type: str = "standard", + resample_filter: List[int] = [1, 1], + ): + + valid_embedding_types = ["fourier", "positional", "zero"] + if embedding_type not in valid_embedding_types: + raise ValueError( + f"Invalid embedding_type: {embedding_type}. Must be one of {valid_embedding_types}." + ) + + valid_encoder_types = ["standard", "skip", "residual"] + if encoder_type not in valid_encoder_types: + raise ValueError( + f"Invalid encoder_type: {encoder_type}. Must be one of {valid_encoder_types}." + ) + + valid_decoder_types = ["standard", "skip"] + if decoder_type not in valid_decoder_types: + raise ValueError( + f"Invalid decoder_type: {decoder_type}. Must be one of {valid_decoder_types}." + ) + + super().__init__(meta=MetaData()) + self.label_dropout = label_dropout + self.embedding_type = embedding_type + emb_channels = model_channels * channel_mult_emb + self.emb_channels = emb_channels + noise_channels = model_channels * channel_mult_noise + init = dict(init_mode="xavier_uniform") + init_zero = dict(init_mode="xavier_uniform", init_weight=1e-5) + init_attn = dict(init_mode="xavier_uniform", init_weight=np.sqrt(0.2)) + block_kwargs = dict( + emb_channels=emb_channels, + num_heads=1, + dropout=dropout, + skip_scale=np.sqrt(0.5), + eps=1e-6, + resample_filter=resample_filter, + resample_proj=True, + adaptive_scale=False, + init=init, + init_zero=init_zero, + init_attn=init_attn, + ) + + # Mapping. 
+ if self.embedding_type != "zero": + self.map_noise = ( + PositionalEmbedding(num_channels=noise_channels, endpoint=True) + if embedding_type == "positional" + else FourierEmbedding(num_channels=noise_channels) + ) + self.map_label = ( + Linear(in_features=label_dim, out_features=noise_channels, **init) + if label_dim + else None + ) + self.map_augment = ( + Linear( + in_features=augment_dim, + out_features=noise_channels, + bias=False, + **init, + ) + if augment_dim + else None + ) + self.map_layer0 = Linear( + in_features=noise_channels, out_features=emb_channels, **init + ) + self.map_layer1 = Linear( + in_features=emb_channels, out_features=emb_channels, **init + ) + + # Encoder. + self.enc = torch.nn.ModuleDict() + cout = in_channels + caux = in_channels + for level, mult in enumerate(channel_mult): + res = img_resolution >> level + if level == 0: + cin = cout + cout = model_channels + self.enc[f"{res}x{res}_conv"] = Conv2d( + in_channels=cin, out_channels=cout, kernel=3, **init + ) + else: + self.enc[f"{res}x{res}_down"] = UNetBlock( + in_channels=cout, out_channels=cout, down=True, **block_kwargs + ) + if encoder_type == "skip": + self.enc[f"{res}x{res}_aux_down"] = Conv2d( + in_channels=caux, + out_channels=caux, + kernel=0, + down=True, + resample_filter=resample_filter, + ) + self.enc[f"{res}x{res}_aux_skip"] = Conv2d( + in_channels=caux, out_channels=cout, kernel=1, **init + ) + if encoder_type == "residual": + self.enc[f"{res}x{res}_aux_residual"] = Conv2d( + in_channels=caux, + out_channels=cout, + kernel=3, + down=True, + resample_filter=resample_filter, + fused_resample=True, + **init, + ) + caux = cout + for idx in range(num_blocks): + cin = cout + cout = model_channels * mult + attn = res in attn_resolutions + self.enc[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, out_channels=cout, attention=attn, **block_kwargs + ) + skips = [ + block.out_channels for name, block in self.enc.items() if "aux" not in name + ] + + # Decoder. + self.dec = torch.nn.ModuleDict() + for level, mult in reversed(list(enumerate(channel_mult))): + res = img_resolution >> level + if level == len(channel_mult) - 1: + self.dec[f"{res}x{res}_in0"] = UNetBlock( + in_channels=cout, out_channels=cout, attention=True, **block_kwargs + ) + self.dec[f"{res}x{res}_in1"] = UNetBlock( + in_channels=cout, out_channels=cout, **block_kwargs + ) + else: + self.dec[f"{res}x{res}_up"] = UNetBlock( + in_channels=cout, out_channels=cout, up=True, **block_kwargs + ) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = model_channels * mult + attn = idx == num_blocks and res in attn_resolutions + self.dec[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, out_channels=cout, attention=attn, **block_kwargs + ) + if decoder_type == "skip" or level == 0: + if decoder_type == "skip" and level < len(channel_mult) - 1: + self.dec[f"{res}x{res}_aux_up"] = Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel=0, + up=True, + resample_filter=resample_filter, + ) + self.dec[f"{res}x{res}_aux_norm"] = GroupNorm( + num_channels=cout, eps=1e-6 + ) + self.dec[f"{res}x{res}_aux_conv"] = Conv2d( + in_channels=cout, out_channels=out_channels, kernel=3, **init_zero + ) + + def forward(self, x, noise_labels, class_labels, augment_labels=None): + + if self.embedding_type != "zero": + # Mapping. 
+ emb = self.map_noise(noise_labels) + emb = ( + emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) + ) # swap sin/cos + if self.map_label is not None: + tmp = class_labels + if self.training and self.label_dropout: + tmp = tmp * ( + torch.rand([x.shape[0], 1], device=x.device) + >= self.label_dropout + ).to(tmp.dtype) + emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features)) + if self.map_augment is not None and augment_labels is not None: + emb = emb + self.map_augment(augment_labels) + emb = silu(self.map_layer0(emb)) + emb = silu(self.map_layer1(emb)) + else: + emb = torch.zeros((noise_labels.shape[0], self.emb_channels)).cuda() + + # Encoder. + skips = [] + aux = x + for name, block in self.enc.items(): + if "aux_down" in name: + aux = block(aux) + elif "aux_skip" in name: + x = skips[-1] = x + block(aux) + elif "aux_residual" in name: + x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) + else: + x = block(x, emb) if isinstance(block, UNetBlock) else block(x) + skips.append(x) + + # Decoder. + aux = None + tmp = None + for name, block in self.dec.items(): + if "aux_up" in name: + aux = block(aux) + elif "aux_norm" in name: + tmp = block(x) + elif "aux_conv" in name: + tmp = block(silu(tmp)) + aux = tmp if aux is None else tmp + aux + else: + if x.shape[1] != block.in_channels: + x = torch.cat([x, skips.pop()], dim=1) + x = block(x, emb) + return aux diff --git a/modulus/models/diffusion/unet.py b/modulus/models/diffusion/unet.py new file mode 100644 index 0000000000..74466e8518 --- /dev/null +++ b/modulus/models/diffusion/unet.py @@ -0,0 +1,178 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import torch + +from modulus.models.meta import ModelMetaData +from modulus.models.module import Module + + +@dataclass +class MetaData(ModelMetaData): + name: str = "UNet" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = True + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class UNet(Module): + """ + U-Net architecture. + + Parameters: + ----------- + img_resolution : int + The resolution of the input/output image. + img_channels : int + Number of color channels. + img_in_channels : int + Number of input color channels. + img_out_channels : int + Number of output color channels. + label_dim : int, optional + Number of class labels; 0 indicates an unconditional model. By default 0. + use_fp16: bool, optional + Execute the underlying model at FP16 precision?, by default False. + sigma_min: float, optional + Minimum supported noise level, by default 0. + sigma_max: float, optional + Maximum supported noise level, by default float('inf'). + sigma_data: float, optional + Expected standard deviation of the training data, by default 0.5. 
+ model_type: str, optional + Class name of the underlying model, by default 'DhariwalUNet'. + **model_kwargs : dict + Keyword arguments for the underlying model. + + + Note: + ----- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + img_resolution, + img_channels, + img_in_channels, + img_out_channels, + label_dim=0, + use_fp16=False, + sigma_min=0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="DhariwalUNet", + **model_kwargs, + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + + self.img_in_channels = img_in_channels + self.img_out_channels = img_out_channels + + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + self.model = globals()[model_type]( + img_resolution=img_resolution, + in_channels=img_in_channels + img_out_channels, + out_channels=img_out_channels, + label_dim=label_dim, + **model_kwargs, + ) + + def forward( + self, x, img_lr, sigma, class_labels=None, force_fp32=False, **model_kwargs + ): + + # SR: concatenate input channels + if img_lr is None: + x = x + else: + x = torch.cat((x, img_lr), dim=1) + + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_skip = 0.0 * c_skip + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_out = torch.ones_like(c_out) + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_in = torch.ones_like(c_in) + c_noise = sigma.log() / 4 + c_noise = 0.0 * c_noise + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + # skip connection - for SR there's size mismatch bwtween input and output + x = x[:, 0 : self.img_out_channels, :, :] + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def round_sigma(self, sigma): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) diff --git a/modulus/models/diffusion/utils.py b/modulus/models/diffusion/utils.py new file mode 100644 index 0000000000..870ad2fc82 --- /dev/null +++ b/modulus/models/diffusion/utils.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import torch + + +def weight_init(shape: tuple, mode: str, fan_in: int, fan_out: int): + """ + Unified routine for initializing weights and biases. + This function provides a unified interface for various weight initialization + strategies like Xavier (Glorot) and Kaiming (He) initializations. + + Parameters + ---------- + shape : tuple + The shape of the tensor to initialize. It could represent weights or biases + of a layer in a neural network. + mode : str + The mode/type of initialization to use. Supported values are: + - "xavier_uniform": Xavier (Glorot) uniform initialization. + - "xavier_normal": Xavier (Glorot) normal initialization. + - "kaiming_uniform": Kaiming (He) uniform initialization. + - "kaiming_normal": Kaiming (He) normal initialization. + fan_in : int + The number of input units in the weight tensor. For convolutional layers, + this typically represents the number of input channels times the kernel height + times the kernel width. + fan_out : int + The number of output units in the weight tensor. For convolutional layers, + this typically represents the number of output channels times the kernel height + times the kernel width. + + Returns + ------- + torch.Tensor + The initialized tensor based on the specified mode. + + Raises + ------ + ValueError + If the provided `mode` is not one of the supported initialization modes. + """ + if mode == "xavier_uniform": + return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) + if mode == "xavier_normal": + return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) + if mode == "kaiming_uniform": + return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) + if mode == "kaiming_normal": + return np.sqrt(1 / fan_in) * torch.randn(*shape) + raise ValueError(f'Invalid init mode "{mode}"') diff --git a/pyproject.toml b/pyproject.toml index da20f408b0..1021326278 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,9 @@ fixable = ["I"] # and `S311` (random number generators) ignore = ["E501", "S311"] +# Exclude the examples folder +exclude = ["examples"] + [tool.ruff.per-file-ignores] # Ignore `F401` (import violations) in all `__init__.py` files, and in `docs/*.py`. "__init__.py" = ["F401"] diff --git a/test/metrics/diffusion/test_fid.py b/test/metrics/diffusion/test_fid.py new file mode 100644 index 0000000000..242588098c --- /dev/null +++ b/test/metrics/diffusion/test_fid.py @@ -0,0 +1,31 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import numpy as np +import pytest + +from modulus.metrics.diffusion import calculate_fid_from_inception_stats + + +def test_fid_calculation(): + mu = np.array([1.0, 2.0]) + sigma = np.array([[1.0, 0.5], [0.5, 2.0]]) + mu_ref = np.array([0.0, 1.0]) + sigma_ref = np.array([[2.0, 0.3], [0.3, 1.5]]) + + fid = calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref) + expected_fid = 2.234758220608337 + + assert pytest.approx(fid, abs=1e-4) == expected_fid diff --git a/test/metrics/diffusion/test_losses.py b/test/metrics/diffusion/test_losses.py new file mode 100644 index 0000000000..9b1559726e --- /dev/null +++ b/test/metrics/diffusion/test_losses.py @@ -0,0 +1,265 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from modulus.metrics.diffusion import ( + EDMLoss, + VELoss, + VPLoss, +) + +# VPLoss tests + + +def test_vploss_initialization(): + loss_func = VPLoss() + assert loss_func.beta_d == 19.9 + assert loss_func.beta_min == 0.1 + assert loss_func.epsilon_t == 1e-5 + + loss_func = VPLoss(beta_d=10.0, beta_min=0.5, epsilon_t=1e-4) + assert loss_func.beta_d == 10.0 + assert loss_func.beta_min == 0.5 + assert loss_func.epsilon_t == 1e-4 + + +def test_sigma_method(): + loss_func = VPLoss() + + # Scalar input + sigma_val = loss_func.sigma(1.0) + assert isinstance(sigma_val, torch.Tensor) + assert sigma_val.item() > 0 + + # Tensor input + t = torch.tensor([1.0, 2.0]) + sigma_vals = loss_func.sigma(t) + assert sigma_vals.shape == t.shape + + +def fake_net(y, sigma, labels, augment_labels=None): + return torch.tensor([1.0]) + + +def test_call_method_vp(): + loss_func = VPLoss() + + images = torch.tensor([[[[1.0]]]]) + labels = None + + # Without augmentation + loss_value = loss_func(fake_net, images, labels) + assert isinstance(loss_value, torch.Tensor) + + # With augmentation + def mock_augment_pipe(imgs): + return imgs, None + + loss_value_with_augmentation = loss_func( + fake_net, images, labels, mock_augment_pipe + ) + assert isinstance(loss_value_with_augmentation, torch.Tensor) + + +# VELoss tests + + +def test_veloss_initialization(): + loss_func = VELoss() + assert loss_func.sigma_min == 0.02 + assert loss_func.sigma_max == 100.0 + + loss_func = VELoss(sigma_min=0.01, sigma_max=50.0) + assert loss_func.sigma_min == 0.01 + assert loss_func.sigma_max == 50.0 + + +def test_call_method_ve(): + loss_func = VELoss() + + images = torch.tensor([[[[1.0]]]]) + labels = None + + # Without augmentation + loss_value = loss_func(fake_net, images, labels) + assert isinstance(loss_value, torch.Tensor) + + # With augmentation + def mock_augment_pipe(imgs): + return imgs, None + + loss_value_with_augmentation = loss_func( + fake_net, images, labels, mock_augment_pipe + ) + assert isinstance(loss_value_with_augmentation, torch.Tensor) + + +# EDMLoss tests + + +def test_edmloss_initialization(): + loss_func = EDMLoss() + assert loss_func.P_mean == -1.2 + assert loss_func.P_std == 1.2 + assert 
loss_func.sigma_data == 0.5 + + loss_func = EDMLoss(P_mean=-2.0, P_std=2.0, sigma_data=0.3) + assert loss_func.P_mean == -2.0 + assert loss_func.P_std == 2.0 + assert loss_func.sigma_data == 0.3 + + +def test_call_method_edm(): + loss_func = EDMLoss() + + img = torch.tensor([[[[1.0]]]]) + labels = None + + # Without augmentation + loss_value = loss_func(fake_net, img, labels) + assert isinstance(loss_value, torch.Tensor) + + # With augmentation + def mock_augment_pipe(imgs): + return imgs, None + + loss_value_with_augmentation = loss_func(fake_net, img, labels, mock_augment_pipe) + assert isinstance(loss_value_with_augmentation, torch.Tensor) + + +# RegressionLoss tests + + +# def test_regressionloss_initialization(): +# loss_func = RegressionLoss() +# assert loss_func.P_mean == -1.2 +# assert loss_func.P_std == 1.2 +# assert loss_func.sigma_data == 0.5 + +# loss_func = RegressionLoss(P_mean=-2.0, P_std=2.0, sigma_data=0.3) +# assert loss_func.P_mean == -2.0 +# assert loss_func.P_std == 2.0 +# assert loss_func.sigma_data == 0.3 + + +# def fake_net(input, y_lr, sigma, labels, augment_labels=None): +# return torch.tensor([1.0]) + + +# def test_call_method(): +# loss_func = RegressionLoss() + +# img_clean = torch.tensor([[[[1.0]]]]) +# img_lr = torch.tensor([[[[0.5]]]]) +# labels = None + +# # Without augmentation +# loss_value = loss_func(fake_net, img_clean, img_lr, labels) +# assert isinstance(loss_value, torch.Tensor) + +# # With augmentation +# def mock_augment_pipe(imgs): +# return imgs, None + +# loss_value_with_augmentation = loss_func( +# fake_net, img_clean, img_lr, labels, mock_augment_pipe +# ) +# assert isinstance(loss_value_with_augmentation, torch.Tensor) + + +# MixtureLoss tests + + +# def test_mixtureloss_initialization(): +# loss_func = MixtureLoss() +# assert loss_func.P_mean == -1.2 +# assert loss_func.P_std == 1.2 +# assert loss_func.sigma_data == 0.5 + +# loss_func = MixtureLoss(P_mean=-2.0, P_std=2.0, sigma_data=0.3) +# assert loss_func.P_mean == -2.0 +# assert loss_func.P_std == 2.0 +# assert loss_func.sigma_data == 0.3 + + +# def fake_net(latent, y_lr, sigma, labels, augment_labels=None): +# return torch.tensor([1.0]) + + +# def test_call_method(): +# loss_func = MixtureLoss() + +# img_clean = torch.tensor([[[[1.0]]]]) +# img_lr = torch.tensor([[[[0.5]]]]) +# labels = None + +# # Without augmentation +# loss_value = loss_func(fake_net, img_clean, img_lr, labels) +# assert isinstance(loss_value, torch.Tensor) + +# # With augmentation +# def mock_augment_pipe(imgs): +# return imgs, None + +# loss_value_with_augmentation = loss_func( +# fake_net, img_clean, img_lr, labels, mock_augment_pipe +# ) +# assert isinstance(loss_value_with_augmentation, torch.Tensor) + + +# ResLoss tests + + +# def test_resloss_initialization(): +# # Mock the model loading +# ResLoss.unet = torch.nn.Linear(1, 1).cuda() + +# loss_func = ResLoss() +# assert loss_func.P_mean == 0.0 +# assert loss_func.P_std == 1.2 +# assert loss_func.sigma_data == 0.5 + +# loss_func = ResLoss(P_mean=-2.0, P_std=2.0, sigma_data=0.3) +# assert loss_func.P_mean == -2.0 +# assert loss_func.P_std == 2.0 +# assert loss_func.sigma_data == 0.3 + + +# def fake_net(latent, y_lr, sigma, labels, augment_labels=None): +# return torch.tensor([1.0]) + + +# def test_call_method(): +# # Mock the model loading +# ResLoss.unet = torch.nn.Linear(1, 1).cuda() + +# loss_func = ResLoss() + +# img_clean = torch.tensor([[[[1.0]]]]) +# img_lr = torch.tensor([[[[0.5]]]]) +# labels = None + +# # Without augmentation +# loss_value = 
loss_func(fake_net, img_clean, img_lr, labels) +# assert isinstance(loss_value, torch.Tensor) + +# # With augmentation +# def mock_augment_pipe(imgs): +# return imgs, None + +# loss_value_with_augmentation = loss_func( +# fake_net, img_clean, img_lr, labels, mock_augment_pipe +# ) +# assert isinstance(loss_value_with_augmentation, torch.Tensor) diff --git a/test/models/data/ddmpp_unet_output.pth b/test/models/data/ddmpp_unet_output.pth new file mode 100644 index 0000000000..423e547dd6 Binary files /dev/null and b/test/models/data/ddmpp_unet_output.pth differ diff --git a/test/models/data/dhariwal_unet_output.pth b/test/models/data/dhariwal_unet_output.pth new file mode 100644 index 0000000000..aa8d14cd33 Binary files /dev/null and b/test/models/data/dhariwal_unet_output.pth differ diff --git a/test/models/data/ncsnpp_unet_output.pth b/test/models/data/ncsnpp_unet_output.pth new file mode 100644 index 0000000000..8fdff993e4 Binary files /dev/null and b/test/models/data/ncsnpp_unet_output.pth differ diff --git a/test/models/diffusion/test_dhariwal_unet.py b/test/models/diffusion/test_dhariwal_unet.py new file mode 100644 index 0000000000..db934d489b --- /dev/null +++ b/test/models/diffusion/test_dhariwal_unet.py @@ -0,0 +1,143 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ruff: noqa: E402 +import os +import sys + +import pytest +import torch + +script_path = os.path.abspath(__file__) +sys.path.append(os.path.join(os.path.dirname(script_path), "..")) + +import common + +from modulus.models.diffusion import DhariwalUNet as UNet + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_dhariwal_unet_forward(device): + torch.manual_seed(0) + model = UNet(img_resolution=64, in_channels=2, out_channels=2).to(device) + input_image = torch.ones([1, 2, 64, 64]).to(device) + noise_labels = noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + + assert common.validate_forward_accuracy( + model, + (input_image, noise_labels, class_labels), + file_name="dhariwal_unet_output.pth", + atol=1e-3, + ) + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_dhariwal_unet_constructor(device): + """Test the Dhariwal UNet constructor options""" + + img_resolution = 16 + in_channels = 2 + out_channels = 2 + model = UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + ).to(device) + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + output_image = model(input_image, noise_labels, class_labels) + assert output_image.shape == (1, out_channels, img_resolution, img_resolution) + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_dhariwal_unet_optims(device): + """Test Dhariwal UNet optimizations""" + + def setup_model(): + model = UNet( + img_resolution=16, + in_channels=2, + out_channels=2, + ).to(device) + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + + return model, [input_image, noise_labels, class_labels] + + # Ideally always check graphs first + model, invar = setup_model() + assert common.validate_cuda_graphs(model, (*invar,)) + + # Check JIT + model, invar = setup_model() + assert common.validate_jit(model, (*invar,)) + # Check AMP + model, invar = setup_model() + assert common.validate_amp(model, (*invar,)) + # Check Combo + model, invar = setup_model() + assert common.validate_combo_optims(model, (*invar,)) + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_dhariwal_unet_checkpoint(device): + """Test Dhariwal UNet checkpoint save/load""" + # Construct FNO models + model_1 = UNet( + img_resolution=16, + in_channels=2, + out_channels=2, + ).to(device) + + model_2 = UNet( + img_resolution=16, + in_channels=2, + out_channels=2, + ).to(device) + # This test doesn't like the model outputs to be the same. 
+ # Change the bias in the last layer of the second model as a hack + # Because this model is initialized with all zeros + with torch.no_grad(): + model_2.out_conv.bias += 1 + + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + assert common.validate_checkpoint( + model_1, model_2, (*[input_image, noise_labels, class_labels],) + ) + + +@common.check_ort_version() +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_dhariwal_unet_deploy(device): + """Test Dhariwal UNet deployment support""" + model = UNet( + img_resolution=16, + in_channels=2, + out_channels=2, + ).to(device) + + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + + assert common.validate_onnx_export( + model, (*[input_image, noise_labels, class_labels],) + ) + assert common.validate_onnx_runtime( + model, (*[input_image, noise_labels, class_labels],) + ) diff --git a/test/models/diffusion/test_song_unet.py b/test/models/diffusion/test_song_unet.py new file mode 100644 index 0000000000..048b2ef163 --- /dev/null +++ b/test/models/diffusion/test_song_unet.py @@ -0,0 +1,217 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ruff: noqa: E402 +import os +import sys + +import pytest +import torch + +script_path = os.path.abspath(__file__) +sys.path.append(os.path.join(os.path.dirname(script_path), "..")) + +import common + +from modulus.models.diffusion import SongUNet as UNet + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_song_unet_forward(device): + torch.manual_seed(0) + # Construct the DDM++ UNet model + model = UNet(img_resolution=64, in_channels=2, out_channels=2).to(device) + input_image = torch.ones([1, 2, 64, 64]).to(device) + noise_labels = noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + + assert common.validate_forward_accuracy( + model, + (input_image, noise_labels, class_labels), + file_name="ddmpp_unet_output.pth", + atol=1e-3, + ) + + torch.manual_seed(0) + # Construct the NCSN++ UNet model + model = UNet( + img_resolution=64, + in_channels=2, + out_channels=2, + embedding_type="fourier", + channel_mult_noise=2, + encoder_type="residual", + resample_filter=[1, 3, 3, 1], + ).to(device) + + assert common.validate_forward_accuracy( + model, + (input_image, noise_labels, class_labels), + file_name="ncsnpp_unet_output.pth", + atol=1e-3, + ) + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_song_unet_constructor(device): + """Test the Song UNet constructor options""" + + # DDM++ + img_resolution = 16 + in_channels = 2 + out_channels = 2 + model = UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + ).to(device) + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + output_image = model(input_image, noise_labels, class_labels) + assert output_image.shape == (1, out_channels, img_resolution, img_resolution) + + # NCSN++ + model = UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + embedding_type="fourier", + channel_mult_noise=2, + encoder_type="residual", + resample_filter=[1, 3, 3, 1], + ).to(device) + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + output_image = model(input_image, noise_labels, class_labels) + assert output_image.shape == (1, out_channels, img_resolution, img_resolution) + + # Also test failure cases + try: + model = UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + embedding_type=None, + ).to(device) + raise AssertionError("Failed to error for invalid argument") + except ValueError: + pass + + try: + model = UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + encoder_type=None, + ).to(device) + raise AssertionError("Failed to error for invalid argument") + except ValueError: + pass + + try: + model = UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + decoder_type=None, + ).to(device) + raise AssertionError("Failed to error for invalid argument") + except ValueError: + pass + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_song_unet_optims(device): + """Test Song UNet optimizations""" + + def setup_model(): + model = UNet( + img_resolution=16, + in_channels=2, + out_channels=2, + embedding_type="fourier", + channel_mult_noise=2, + encoder_type="residual", + resample_filter=[1, 3, 3, 1], + ).to(device) + noise_labels = 
torch.randn([1]).to(device)
+        class_labels = torch.randint(0, 1, (1, 1)).to(device)
+        input_image = torch.ones([1, 2, 16, 16]).to(device)
+
+        return model, [input_image, noise_labels, class_labels]
+
+    # Ideally always check graphs first
+    model, invar = setup_model()
+    assert common.validate_cuda_graphs(model, (*invar,))
+
+    # Check JIT
+    model, invar = setup_model()
+    assert common.validate_jit(model, (*invar,))
+    # Check AMP
+    model, invar = setup_model()
+    assert common.validate_amp(model, (*invar,))
+    # Check Combo
+    model, invar = setup_model()
+    assert common.validate_combo_optims(model, (*invar,))
+
+
+@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+def test_song_unet_checkpoint(device):
+    """Test Song UNet checkpoint save/load"""
+    # Construct Song UNet models
+    model_1 = UNet(
+        img_resolution=16,
+        in_channels=2,
+        out_channels=2,
+    ).to(device)
+
+    model_2 = UNet(
+        img_resolution=16,
+        in_channels=2,
+        out_channels=2,
+    ).to(device)
+
+    noise_labels = torch.randn([1]).to(device)
+    class_labels = torch.randint(0, 1, (1, 1)).to(device)
+    input_image = torch.ones([1, 2, 16, 16]).to(device)
+    assert common.validate_checkpoint(
+        model_1, model_2, (*[input_image, noise_labels, class_labels],)
+    )
+
+
+@common.check_ort_version()
+@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+def test_song_unet_deploy(device):
+    """Test Song UNet deployment support"""
+    model = UNet(
+        img_resolution=16,
+        in_channels=2,
+        out_channels=2,
+        embedding_type="fourier",
+        channel_mult_noise=2,
+        encoder_type="residual",
+        resample_filter=[1, 3, 3, 1],
+    ).to(device)
+
+    noise_labels = torch.randn([1]).to(device)
+    class_labels = torch.randint(0, 1, (1, 1)).to(device)
+    input_image = torch.ones([1, 2, 16, 16]).to(device)
+
+    assert common.validate_onnx_export(
+        model, (*[input_image, noise_labels, class_labels],)
+    )
+    assert common.validate_onnx_runtime(
+        model, (*[input_image, noise_labels, class_labels],)
+    )
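
The preconditioner wrappers above all share one contract: the network output F_x is recombined with the noisy input as D_x = c_skip * x + c_out * F_x, the input is scaled by c_in, and the noise level enters through c_noise. A standalone sketch of the EDMPrecond coefficients (illustrative only, not part of the patch; the helper name edm_coefficients is invented here) makes the limiting behaviour easy to inspect without instantiating a network:

import torch


def edm_coefficients(sigma: torch.Tensor, sigma_data: float = 0.5):
    # Same scalings as in EDMPrecond.forward
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2).sqrt()
    c_in = 1 / (sigma_data**2 + sigma**2).sqrt()
    c_noise = sigma.log() / 4
    return c_skip, c_out, c_in, c_noise


sigma = torch.tensor([0.002, 0.5, 80.0])
c_skip, c_out, c_in, c_noise = edm_coefficients(sigma)
# Low noise: c_skip -> 1 and c_out -> 0, so D_x is dominated by the input x.
# High noise: c_skip -> 0 and c_in shrinks, so D_x is dominated by the network output.
print(c_skip, c_out, c_in, c_noise)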
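
The iDDPM-style preconditioner instead restricts sigma to a discrete ladder u[0] >= ... >= u[M] built from alpha_bar, and round_sigma snaps a continuous value onto it. The same recursion can be reproduced in isolation (constructor defaults C_1=0.001, C_2=0.008, M=1000 assumed; illustrative only):

import numpy as np
import torch

C_1, C_2, M = 0.001, 0.008, 1000


def alpha_bar(j):
    # iDDPM cosine schedule, as in the preconditioner above
    j = torch.as_tensor(j)
    return (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2


u = torch.zeros(M + 1)
for j in range(M, 0, -1):  # M, ..., 1
    u[j - 1] = (
        (u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1
    ).sqrt()
sigma_min, sigma_max = float(u[M - 1]), float(u[0])

# round_sigma: nearest discrete level (or its index) for arbitrary sigma values
sigma = torch.tensor([0.1, 1.0, 10.0])
index = torch.cdist(sigma.reshape(1, -1, 1), u.reshape(1, -1, 1)).argmin(2).flatten()
print(sigma_min, sigma_max, u[index])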
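
On the metrics side, test_fid.py pins calculate_fid_from_inception_stats to roughly 2.2348 for a pair of hand-written 2-D statistics. That value is consistent with the standard Fréchet distance between two Gaussians, which can be checked independently (assumption: the Modulus metric follows this standard formula; the check below uses scipy and does not import Modulus):

import numpy as np
import scipy.linalg

mu = np.array([1.0, 2.0])
sigma = np.array([[1.0, 0.5], [0.5, 2.0]])
mu_ref = np.array([0.0, 1.0])
sigma_ref = np.array([[2.0, 0.3], [0.3, 1.5]])

# FID = ||mu - mu_ref||^2 + Tr(sigma + sigma_ref - 2 * (sigma @ sigma_ref)^(1/2))
covmean, _ = scipy.linalg.sqrtm(sigma @ sigma_ref, disp=False)
fid = np.square(mu - mu_ref).sum() + np.trace(sigma + sigma_ref - 2.0 * covmean.real)
print(fid)  # ~2.2348, matching expected_fid in the test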
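
Finally, weight_init in modulus/models/diffusion/utils.py produces Xavier/Kaiming initialisations by scaling uniform or normal samples directly rather than going through torch.nn.init. A minimal sketch restating two of its branches and checking the resulting spread (purely illustrative):

import numpy as np
import torch


def weight_init(shape, mode, fan_in, fan_out):
    # Two of the branches from modulus/models/diffusion/utils.py
    if mode == "xavier_uniform":
        return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1)
    if mode == "kaiming_normal":
        return np.sqrt(1 / fan_in) * torch.randn(*shape)
    raise ValueError(f'Invalid init mode "{mode}"')


w = weight_init((512, 256), "xavier_uniform", fan_in=256, fan_out=512)
# Uniform on [-a, a] with a = sqrt(6 / (fan_in + fan_out)) has std a / sqrt(3),
# i.e. sqrt(2 / (fan_in + fan_out)), the usual Glorot value.
print(w.std().item(), np.sqrt(2 / (256 + 512)))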