From 5fbd02de9f6b8bbdb4e442f428d48885c29c1582 Mon Sep 17 00:00:00 2001 From: Paul Sweeney Date: Mon, 2 Oct 2023 11:16:11 +0100 Subject: [PATCH] Minor documentation updates and code optimisations. Code now outputs 'args' settings in main. Also, include a function to perform clahe to 2D slices of a 3D image. (#4) --- README.md | 6 +- custom_callback.py | 29 +++--- dataset.py | 233 ++++++++++++++++++++++++++++----------------- discriminator.py | 4 + loss_functions.py | 55 ++++++----- main.py | 19 +++- utils.py | 123 +++++++++++++++++++++--- vangan.py | 60 +++++++----- 8 files changed, 354 insertions(+), 175 deletions(-) diff --git a/README.md b/README.md index f71ff4c..897fb86 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ Tensorflow and the remaining Python packages below can be installed in a [_conda The remaining required Python packages can be installed using _pip_ in a terminal window: ```bash -pip install opencv-python scikit-image tqdm tensorflow_addons tensorflow-mri joblib matplotlib +pip install opencv-python scikit-image tqdm tensorflow_addons joblib matplotlib ``` VAN-GAN has been tested on Ubuntu 22.04.2 LTS with Python 3.9.16 and the following package versions: @@ -59,10 +59,10 @@ VAN-GAN code was originally developed by [Paul W. Sweeney](https://www.psweeney. Please get in contact if you have any questions. -## References +## Citation If you use this code or data, we kindly ask that you please cite the below: > [Segmentation of 3D blood vessel networks using unsupervised deep learning](https://doi.org/10.1101/2023.04.30.538453)
-> Paul W. Sweeney et al. +> Paul W. Sweeney et al. *bioRxiv* ## Licence The project is licenced under the MIT Licence. diff --git a/custom_callback.py b/custom_callback.py index 79c9941..83d8b17 100644 --- a/custom_callback.py +++ b/custom_callback.py @@ -25,15 +25,11 @@ def __init__(self, self.imaging_val_data = imaging_val_data self.segmentation_val_data = segmentation_val_data self.process_imaging_domain = process_imaging_domain - self.period = args.PERIOD_2D_CALLBACK, - self.period3D = args.PERIOD_3D_CALLBACK, - self.model_path = args.output_dir, + self.period = args.PERIOD_2D_CALLBACK + self.period3D = args.PERIOD_3D_CALLBACK + self.model_path = args.output_dir self.dims = args.DIMENSIONS - self.period = self.period[0] - self.period3D = self.period3D[0] - self.model_path = self.model_path[0] - def save_model(self, model, epoch): """Save the trained model at the given epoch. @@ -43,10 +39,10 @@ def save_model(self, model, epoch): """ # if epoch > 100: - model.gen_AB.save(os.path.join(self.model_path, "checkpoints/e{epoch}_genAB".format(epoch=epoch + 1))) - model.gen_BA.save(os.path.join(self.model_path, "checkpoints/e{epoch}_genBA".format(epoch=epoch + 1))) - model.disc_A.save(os.path.join(self.model_path, "checkpoints/e{epoch}_discA".format(epoch=epoch + 1))) - model.disc_B.save(os.path.join(self.model_path, "checkpoints/e{epoch}_discB".format(epoch=epoch + 1))) + model.gen_IS.save(os.path.join(self.model_path, "checkpoints/e{epoch}_genAB".format(epoch=epoch + 1))) + model.gen_SI.save(os.path.join(self.model_path, "checkpoints/e{epoch}_genBA".format(epoch=epoch + 1))) + model.disc_I.save(os.path.join(self.model_path, "checkpoints/e{epoch}_discA".format(epoch=epoch + 1))) + model.disc_S.save(os.path.join(self.model_path, "checkpoints/e{epoch}_discB".format(epoch=epoch + 1))) def stitch_subvolumes(self, gen, img, subvol_size, epoch=-1, stride=(25, 25, 128), @@ -173,7 +169,7 @@ def stitch_subvolumes(self, gen, img, subvol_size, start_dep:(start_dep + kD)] if 
process_img and self.process_imaging_domain is not None: - arr = self.process_imaging_domain(arr) + arr = self.process_imaging_domain(arr, axis=None, keepdims=False) arr = gen(np.expand_dims(arr, axis=0), training=False)[0] @@ -419,14 +415,13 @@ def updateDiscriminatorNoise(self, model, init_noise, epoch, args): else: decay_rate = epoch / args.NO_NOISE noise = init_noise * (1. - decay_rate) + if noise < 0.0: + noise = 0.0 # noise = 0.9 ** (epoch + 1) print('Noise std: %0.5f' % noise) for layer in model.layers: - if type(layer) == layers.GaussianNoise: - if noise > 0.: - layer.stddev = noise - else: - layer.stddev = 0.0 + if isinstance(layer, tf.keras.layers.GaussianNoise): + layer.stddev = noise def on_epoch_start(self, model, epoch, args, logs=None): """ diff --git a/dataset.py b/dataset.py index cb6e48a..ec4891d 100644 --- a/dataset.py +++ b/dataset.py @@ -1,15 +1,23 @@ import random import math +import os import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from skimage import io -from utils import get_vacuum +from utils import get_vacuum, fast_clahe, clahe_3d class DatasetGen: - def __init__(self, args, imaging_domain_data, seg_domain_data, strategy: tf.distribute.Strategy, otf_imaging=None): + def __init__(self, + args, + imaging_domain_data, + seg_domain_data, + strategy: tf.distribute.Strategy, + otf_imaging=None, + semi_supervised_dir=None): """ Setting shard policy for distributed dataset """ + self.feature_indices = None options = tf.data.Options() options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA @@ -17,14 +25,12 @@ def __init__(self, args, imaging_domain_data, seg_domain_data, strategy: tf.dist if args.DIMENSIONS == 2: self.imaging_output_shapes = (None, None, args.CHANNELS) self.segmentation_output_shapes = (None, None, 1) - self.imaging_patch_shape = (args.GLOBAL_BATCH_SIZE, - args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], args.CHANNELS) + self.imaging_patch_shape = 
(args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], args.CHANNELS) self.segmentation_patch_shape = (args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], 1) else: self.imaging_output_shapes = (None, None, None, args.CHANNELS) self.segmentation_output_shapes = (None, None, None, 1) - self.imaging_patch_shape = (args.GLOBAL_BATCH_SIZE, - args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], args.SUBVOL_PATCH_SIZE[2], + self.imaging_patch_shape = (args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], args.SUBVOL_PATCH_SIZE[2], args.CHANNELS) self.segmentation_patch_shape = ( args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], args.SUBVOL_PATCH_SIZE[2], 1) @@ -34,6 +40,11 @@ def __init__(self, args, imaging_domain_data, seg_domain_data, strategy: tf.dist self.segmentation_paths = seg_domain_data self.args = args self.otf_imaging = otf_imaging + if semi_supervised_dir is not None: + self.semi_supervised = True + self.semi_supervised_dir = semi_supervised_dir + else: + self.semi_supervised = False self.IMAGE_THRESH = 0.5 self.SEG_THRESH = 0.8 self.GLOBAL_BATCH_SIZE = args.GLOBAL_BATCH_SIZE @@ -48,9 +59,12 @@ def __init__(self, args, imaging_domain_data, seg_domain_data, strategy: tf.dist output_shapes=self.imaging_output_shapes) self.imaging_train_dataset = self.imaging_train_dataset.repeat() self.imaging_train_dataset = self.imaging_train_dataset.with_options(options) - self.imaging_train_dataset = self.imaging_train_dataset.batch(self.GLOBAL_BATCH_SIZE, drop_remainder=True) self.imaging_train_dataset = self.imaging_train_dataset.map(self.process_imaging_domain, num_parallel_calls=tf.data.AUTOTUNE) + self.imaging_train_dataset = self.imaging_train_dataset.batch(self.GLOBAL_BATCH_SIZE, drop_remainder=True) + if self.otf_imaging is not None: + self.imaging_train_dataset = self.imaging_train_dataset.map(self.otf_imaging, + num_parallel_calls=tf.data.AUTOTUNE) ''' Create imaging validation dataset ''' self.imaging_val_dataset = tf.data.Dataset.from_generator(lambda: 
self.imaging_datagen('validation'), @@ -58,9 +72,12 @@ def __init__(self, args, imaging_domain_data, seg_domain_data, strategy: tf.dist output_shapes=self.imaging_output_shapes) self.imaging_val_dataset = self.imaging_val_dataset.repeat() self.imaging_val_dataset = self.imaging_val_dataset.with_options(options) - self.imaging_val_dataset = self.imaging_val_dataset.batch(self.GLOBAL_BATCH_SIZE, drop_remainder=True) - self.imaging_val_dataset = self.imaging_val_dataset.map(map_func=self.process_imaging_domain, + self.imaging_val_dataset = self.imaging_val_dataset.map(self.process_imaging_domain, num_parallel_calls=tf.data.AUTOTUNE) + self.imaging_val_dataset = self.imaging_val_dataset.batch(self.GLOBAL_BATCH_SIZE, drop_remainder=True) + if self.otf_imaging is not None: + self.imaging_val_dataset = self.imaging_val_dataset.map(self.otf_imaging, + num_parallel_calls=tf.data.AUTOTUNE) ''' Create segmentation train dataset ''' self.segmentation_train_dataset = tf.data.Dataset.from_generator( @@ -126,11 +143,17 @@ def imaging_datagen(self, typ='training'): iter_i = 0 np.random.shuffle(img_dataset) - file = img_dataset[iter_i * self.args.GLOBAL_BATCH_SIZE:(iter_i + 1) * self.args.GLOBAL_BATCH_SIZE] + start_idx = iter_i * self.args.GLOBAL_BATCH_SIZE + end_idx = (iter_i + 1) * self.args.GLOBAL_BATCH_SIZE + + if end_idx > len(img_dataset): + end_idx = len(img_dataset) + + file = img_dataset[start_idx:end_idx] # Load batch of full size images for idx, filename in enumerate(file): - yield tf.convert_to_tensor(np.rot90(np.load(filename), np.random.choice([-1, 0, 1])), dtype=tf.float32) + yield tf.convert_to_tensor(np.load(filename), dtype=tf.float32) iter_i += 1 @@ -156,90 +179,101 @@ def segmentation_datagen(self, typ='training'): # Load batch of full size images for idx, filename in enumerate(file): - yield tf.convert_to_tensor(np.rot90(np.load(filename), - np.random.choice([-1, 0, 1])), dtype=tf.float32) + if self.semi_supervised: + ss_filename = 
os.path.join(self.semi_supervised_dir, os.path.basename(filename)) + yield tf.convert_to_tensor(np.concatenate((np.load(filename), + np.load(ss_filename)), + axis=0), + dtype=tf.float32) + else: + yield tf.convert_to_tensor(np.load(filename), dtype=tf.float32) iter_s += 1 def imaging_val_datagen(self): while True: - i = random.randint(0, self.imaging_paths['validation'].shape[0] - 1) + i = random.randint(0, len(self.imaging_paths['validation']) - 1) yield tf.convert_to_tensor(np.load(self.imaging_paths['validation'][i]), dtype=tf.float32), i def segmentation_val_datagen(self): while True: - i = random.randint(0, self.segmentation_paths['validation'].shape[0] - 1) + i = random.randint(0, len(self.segmentation_paths['validation']) - 1) yield tf.convert_to_tensor(np.load(self.segmentation_paths['validation'][i]), dtype=tf.float32), i ''' Functions for data preprocessing ''' + @tf.function + def random_spatial_augmentation(self, image, max_rotation_angle=180, preserve_depth_orientation=False): + # Randomly flip horizontally + image = tf.cond(tf.random.uniform(()) > 0.5, lambda: tf.image.flip_left_right(image), lambda: image) + + # Randomly flip vertically + image = tf.cond(tf.random.uniform(()) > 0.5, lambda: tf.image.flip_up_down(image), lambda: image) + + if not preserve_depth_orientation: + # Randomly rotate the image + rotation_angle = tf.random.uniform((), minval=-max_rotation_angle, maxval=max_rotation_angle) * ( + math.pi / 180.0) + image = tf.image.rot90(image, k=tf.cast(rotation_angle // 90, dtype=tf.int32)) + + return image + def process_imaging_domain(self, image): """ Standardizes image data and creates subvolumes """ - subvol = tf.image.random_crop(image, size=self.imaging_patch_shape) - if self.otf_imaging is not None: - subvol = self.otf_imaging(subvol) - return subvol + # subvol = tf.image.random_crop(image, size=self.imaging_patch_shape) + # if self.otf_imaging is not None: + # subvol = self.otf_imaging(subvol) + arr = tf.image.random_crop(image, 
size=self.imaging_patch_shape) + # arr = clahe_3d(arr) + return self.random_spatial_augmentation(arr, preserve_depth_orientation=True) @tf.function def process_seg_domain(self, image): - """ - Randomly crops the input_mask around a randomly selected feature voxel. - - Args: - self: - image (tf.Tensor): The 4D input segmentation mask (depth, width, length, channel). - Features are labeled as 1, background as -1. - - Returns: - cropped_mask (tf.Tensor): The randomly cropped segmentation mask. - """ - - # Get the indices of feature voxels - feature_indices = tf.where(tf.equal(image, 1)) - - # Randomly select a feature voxel - random_feature_index = tf.cast(tf.random.shuffle(feature_indices)[0], tf.int32) - - # Calculate the cropping window based on the selected feature voxel - crop_start_depth = random_feature_index[0] - self.segmentation_patch_shape[0] // 2 - crop_start_width = random_feature_index[1] - self.segmentation_patch_shape[1] // 2 - crop_start_length = random_feature_index[2] - self.segmentation_patch_shape[2] // 2 - - # Calculate crop_end coordinates symmetrically based on image dimensions - crop_end_depth = crop_start_depth + self.segmentation_patch_shape[0] - crop_end_width = crop_start_width + self.segmentation_patch_shape[1] - crop_end_length = crop_start_length + self.segmentation_patch_shape[2] - - image_shape = tf.shape(image) - - # Adjust cropping symmetrically if necessary - if crop_start_depth < 0: - crop_end_depth -= crop_start_depth - crop_start_depth = 0 - elif crop_end_depth > image_shape[0]: - crop_start_depth -= crop_end_depth - image_shape[0] - crop_end_depth = image_shape[0] - - if crop_start_width < 0: - crop_end_width -= crop_start_width - crop_start_width = 0 - elif crop_end_width > image_shape[1]: - crop_start_width -= crop_end_width - image_shape[1] - crop_end_width = image_shape[1] - - if crop_start_length < 0: - crop_end_length -= crop_start_length - crop_start_length = 0 - elif crop_end_length > image_shape[2]: - crop_start_length 
-= crop_end_length - image_shape[2] - crop_end_length = image_shape[2] - - # Crop the tensor - cropped_mask = image[crop_start_depth:crop_end_depth, - crop_start_width:crop_end_width, - crop_start_length:crop_end_length, :] - - return cropped_mask + # Initialize a loop counter + i = tf.constant(0) + + # Define the maximum number of iterations + max_iterations = tf.constant(200) + + # Initialize arr + arr = tf.image.random_crop(image, size=self.segmentation_patch_shape) + + # Start a while loop + def condition(i, arr): + return tf.math.logical_and(i < max_iterations, tf.math.reduce_max(arr) < self.SEG_THRESH) + + def body(i, _): + # Generate a new random crop from the original image + new_arr = tf.image.random_crop(image, size=self.segmentation_patch_shape) + return i + 1, new_arr + + _, arr = tf.while_loop(condition, body, [i, arr]) + + return self.random_spatial_augmentation(arr) + + # @tf.function + # def process_imaging_domain(self, image): + # # Initialize a loop counter + # i = tf.constant(0) + # + # # Define the maximum number of iterations + # max_iterations = tf.constant(10) + # + # # Initialize arr + # arr = tf.image.random_crop(image, size=self.imaging_patch_shape) + # + # # Start a while loop + # def condition(i, arr): + # return tf.math.logical_and(i < max_iterations, tf.math.reduce_max(arr) < 0.) 
+ # + # def body(i, _): + # # Generate a new random crop from the original image + # new_arr = tf.image.random_crop(image, size=self.imaging_patch_shape) + # return i + 1, new_arr + # + # _, arr = tf.while_loop(condition, body, [i, arr]) + # + # return self.random_spatial_augmentation(arr) def plot_sample_dataset(self): """ @@ -262,12 +296,18 @@ def plot_sample_dataset(self): else: nfig = 6 - fig, axs = plt.subplots(nfig + 1, 2, figsize=(10, 15)) + if self.semi_supervised: + fig, axs = plt.subplots(nfig + 1, 3, figsize=(10, 15)) + else: + fig, axs = plt.subplots(nfig + 1, 2, figsize=(10, 15)) fig.subplots_adjust(hspace=0.5) for i, samples in enumerate(zip(self.imaging_train_dataset.take(1), self.segmentation_train_dataset.take(1))): dI = samples[0][0].numpy() dS = samples[1][0].numpy() + if self.semi_supervised: + dIS = dS[self.segmentation_patch_shape[0]:, ] + dS = dS[:self.segmentation_patch_shape[0], ] if self.args.DIMENSIONS == 3: ''' Save 3D images ''' io.imsave("./GANMonitor/Imaging_Test_Input.tiff", @@ -285,33 +325,50 @@ def plot_sample_dataset(self): axs[0, 1].imshow(showS, cmap='gray') else: for j in range(0, nfig): - showI = (dI[:, :, j * int(self.args.SUBVOL_PATCH_SIZE[2] / nfig), ]) - showS = (dS[:, :, j * int(self.args.SUBVOL_PATCH_SIZE[2] / nfig), ]) + showI = (dI[:, :, j * int(self.segmentation_patch_shape[2] / nfig), ]) + showS = (dS[:, :, j * int(self.segmentation_patch_shape[2] / nfig), ]) axs[j, 0].imshow(showI, cmap='gray') axs[j, 1].imshow(showS, cmap='gray') + if self.semi_supervised: + showIS = (dIS[:, :, j * int(self.segmentation_patch_shape[2] / nfig), ]) + axs[j, 2].imshow(showIS, cmap='gray') ''' Include histograms ''' axs[nfig, 0].hist(dI.ravel(), bins=256, range=(np.amin(dI), np.amax(dI)), fc='k', ec='k', density=True) axs[nfig, 1].hist(dS.ravel(), bins=256, range=(np.amin(dS), np.amax(dS)), fc='k', ec='k', density=True) + if self.semi_supervised: + axs[nfig, 2].hist(dIS.ravel(), bins=256, range=(np.amin(dIS), np.amax(dIS)), fc='k', 
ec='k', + density=True) # Set axis labels - axs[0, 0].set_title('Imaging Dataset Example (XY Slices)') - axs[0, 1].set_title('Segmentation Dataset Example (XY Slices)') + axs[0, 0].set_title('Imaging Dataset (XY)') + axs[0, 1].set_title('Segmentation Dataset (XY)') + if self.semi_supervised: + axs[0, 2].set_title('Paired Imaging Dataset (XY)') axs[nfig, 0].set_ylabel('Voxel Frequency') plt.show(block=False) plt.close() if self.args.DIMENSIONS == 3: - _, axs = plt.subplots(nfig, 2, figsize=(10, 15)) + if self.semi_supervised: + _, axs = plt.subplots(nfig, 3, figsize=(10, 15)) + else: + _, axs = plt.subplots(nfig, 2, figsize=(10, 15)) for j in range(0, nfig): - showI = dI[:, j * int(self.args.SUBVOL_PATCH_SIZE[1] / nfig), :, 0] - showS = dS[:, j * int(self.args.SUBVOL_PATCH_SIZE[1] / nfig), :self.args.SUBVOL_PATCH_SIZE[2] - 1, - 0] + showI = dI[:, j * int(self.segmentation_patch_shape[1] / nfig), :, ] + showS = dS[:, j * int(self.segmentation_patch_shape[1] / nfig), + :self.args.SUBVOL_PATCH_SIZE[2] - 1, ] axs[j, 0].imshow(showI, cmap='gray') axs[j, 1].imshow(showS, cmap='gray') + if self.semi_supervised: + showIS = dIS[:, j * int(self.segmentation_patch_shape[1] / nfig), + :self.args.SUBVOL_PATCH_SIZE[2] - 1, ] + axs[j, 2].imshow(showIS, cmap='gray') # Set axis labels - axs[0, 0].set_title('Imaging Dataset Example (YZ Slices)') - axs[0, 1].set_title('Segmentation Dataset Example (YZ Slices)') + axs[0, 0].set_title('Imaging Dataset (YZ)') + axs[0, 1].set_title('Segmentation Dataset (YZ)') + if self.semi_supervised: + axs[0, 2].set_title('Paired Dataset (YZ)') plt.show(block=False) plt.close() diff --git a/discriminator.py b/discriminator.py index df14c1f..d3bd3b5 100644 --- a/discriminator.py +++ b/discriminator.py @@ -11,10 +11,12 @@ def get_discriminator( kernel_initializer='he_normal', num_downsampling=3, use_dropout=False, + dropout_rate=0.2, wasserstein=False, use_SN=False, use_input_noise=False, use_layer_noise=False, + use_standardisation=False, name=None, 
noise_std=0.1 ): @@ -80,6 +82,7 @@ def get_discriminator( kernel_size=(4, 4, 4), strides=(2, 2, 2), use_dropout=use_dropout, + dropout_rate=dropout_rate, use_spec_norm=use_SN, use_layer_noise=use_layer_noise, noise_std=noise_std @@ -92,6 +95,7 @@ def get_discriminator( kernel_size=(4, 4, 4), strides=(1, 1, 1), use_dropout=use_dropout, + dropout_rate=dropout_rate, padding='same', use_spec_norm=use_SN, use_layer_noise=use_layer_noise, diff --git a/loss_functions.py b/loss_functions.py index 9fa114b..e366ca5 100644 --- a/loss_functions.py +++ b/loss_functions.py @@ -1,6 +1,6 @@ import tensorflow as tf import numpy as np -from utils import min_max_norm_tf +from utils import min_max_norm_tf, z_score_norm_tf from clDice_func import soft_dice_cldice_loss @@ -17,7 +17,9 @@ def reduce_mean(self, inputs, axis=None, keepdims=False): Returns: - A tensor with the mean of the inputs tensor along the given axis divided by the global batch size. """ - return tf.reduce_mean(inputs, axis=axis, keepdims=keepdims) / self.global_batch_size + + arr = tf.reduce_mean(inputs, axis=axis, keepdims=keepdims) + return tf.reduce_sum(arr) / self.global_batch_size @tf.function @@ -70,23 +72,23 @@ def MSE(self, y_true, y_pred): def L4(self, y_true, y_pred): """ Compute the per-sample L4 loss between the true and predicted tensors. - + Args: - y_true: A tensor of true values. - y_pred: A tensor of predicted values. - + Returns: - A scalar tensor representing the per-sample L4 loss between the true and predicted tensors. 
""" return reduce_mean(self, tf.math.pow(y_true - y_pred, 4), axis=list(range(1, len(y_true.shape)))) + @tf.function def ssim_loss_3d(y_true, y_pred, max_val=1.0, filter_size=3, filter_sigma=1.5, k1=0.01, k2=0.03): - # Create Gaussian filter def gaussian_filter(size, sigma): grid = tf.range(-size // 2 + 1, size // 2 + 1, dtype=tf.float32) - gaussian_filter = tf.exp(-0.5 * (grid / sigma)**2) / (sigma * tf.sqrt(2.0 * np.pi)) + gaussian_filter = tf.exp(-0.5 * (grid / sigma) ** 2) / (sigma * tf.sqrt(2.0 * np.pi)) return gaussian_filter / tf.reduce_sum(gaussian_filter) # Create 3D Gaussian filter @@ -97,32 +99,33 @@ def gaussian_filter(size, sigma): # Compute mean and variance mu_true = tf.nn.conv3d(y_true, filter_3d, strides=[1, 1, 1, 1, 1], padding='SAME') mu_pred = tf.nn.conv3d(y_pred, filter_3d, strides=[1, 1, 1, 1, 1], padding='SAME') - mu_true_sq = mu_true**2 - mu_pred_sq = mu_pred**2 + mu_true_sq = mu_true ** 2 + mu_pred_sq = mu_pred ** 2 mu_true_pred = mu_true * mu_pred - sigma_true_sq = tf.nn.conv3d(y_true**2, filter_3d, strides=[1, 1, 1, 1, 1], padding='SAME') - mu_true_sq - sigma_pred_sq = tf.nn.conv3d(y_pred**2, filter_3d, strides=[1, 1, 1, 1, 1], padding='SAME') - mu_pred_sq + sigma_true_sq = tf.nn.conv3d(y_true ** 2, filter_3d, strides=[1, 1, 1, 1, 1], padding='SAME') - mu_true_sq + sigma_pred_sq = tf.nn.conv3d(y_pred ** 2, filter_3d, strides=[1, 1, 1, 1, 1], padding='SAME') - mu_pred_sq sigma_true_pred = tf.nn.conv3d(y_true * y_pred, filter_3d, strides=[1, 1, 1, 1, 1], padding='SAME') - mu_true_pred - c1 = (k1 * max_val)**2 - c2 = (k2 * max_val)**2 + c1 = (k1 * max_val) ** 2 + c2 = (k2 * max_val) ** 2 - ssim_map = (2 * mu_true_pred + c1) * (2 * sigma_true_pred + c2) / ((mu_true_sq + mu_pred_sq + c1) * (sigma_true_sq + sigma_pred_sq + c2)) + ssim_map = (2 * mu_true_pred + c1) * (2 * sigma_true_pred + c2) / ( + (mu_true_sq + mu_pred_sq + c1) * (sigma_true_sq + sigma_pred_sq + c2)) # Compute the mean SSIM loss across the batch - return 1.0 - 
tf.reduce_mean(ssim_map) + return 1.0 - ssim_map @tf.function def wasserstein_loss(prob_real_is_real, prob_fake_is_real): """ Compute the Wasserstein loss between the probabilities that the real inputs are real and the generated inputs are real. - + Args: - prob_real_is_real: A tensor representing the probability that the real inputs are real. - prob_fake_is_real: A tensor representing the probability that the generated inputs are real. - + Returns: - A scalar tensor representing the Wasserstein loss between the two input probability tensors. """ @@ -176,11 +179,14 @@ def cycle_loss(self, real_image, cycled_image, typ=None): elif typ == "mse": return MSE(self, real_image, cycled_image) * self.lambda_cycle elif typ == "L4": - return L4(self, real_image, cycled_image) * self.lambda_cycle + return L4(self, + real_image, + cycled_image) * self.lambda_cycle else: real = min_max_norm_tf(real_image, axis=(1, 2, 3, 4)) cycled = min_max_norm_tf(cycled_image, axis=(1, 2, 3, 4)) loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=False, reduction=tf.keras.losses.Reduction.NONE) + # loss_obj = tf.keras.losses.BinaryFocalCrossentropy(from_logits=False, reduction=tf.keras.losses.Reduction.NONE) return reduce_mean(self, loss_obj(real, cycled)) * self.lambda_cycle @@ -196,27 +202,28 @@ def cycle_reconstruction(self, real_image, cycled_image): Returns: - loss: float Tensor, representing the per sample cycle reconstruction loss """ - real = min_max_norm_tf(real_image, axis=(1, 2, 3, 4)) - cycled = min_max_norm_tf(cycled_image, axis=(1, 2, 3, 4)) - return reduce_mean(self, ssim_loss_3d(real, cycled, max_val=1.0)) * self.lambda_cycle + return reduce_mean(self, + ssim_loss_3d(min_max_norm_tf(real_image, axis=(1, 2, 3, 4)), + min_max_norm_tf(cycled_image, axis=(1, 2, 3, 4)), max_val=1.0) + ) * self.lambda_reconstruction @tf.function def cycle_seg_loss(self, real_image, cycled_image): """ Compute the segmentation loss between the real image and the cycled image - + Args: - 
real_image: a tensor of shape (batch_size, image_size, image_size, channels) representing the real image - cycled_image: a tensor of shape (batch_size, image_size, image_size, channels) representing the cycled image - + Returns: - a scalar tensor representing the segmentation loss """ real = min_max_norm_tf(real_image, axis=(1, 2, 3, 4)) cycled = min_max_norm_tf(cycled_image, axis=(1, 2, 3, 4)) cl_loss_obj = soft_dice_cldice_loss() - return reduce_mean(self, cl_loss_obj(real, cycled)) * self.lambda_cycle + return cl_loss_obj(real, cycled) * (self.lambda_topology / self.n_devices) @tf.function @@ -289,7 +296,7 @@ def discriminator_loss_fn(self, real_image, fake_image, typ=None, from_logits=Tr real_image: A tensor representing the real image. fake_image: A tensor representing the fake image. typ (str, optional): The type of loss function to use. Defaults to None. - from_logits (bool, optional): Whether to apply sigmoid activation function to the predictions. + from_logits (bool, optional): Whether to apply sigmoid activation function to the predictions. Defaults to True. 
Returns: diff --git a/main.py b/main.py index 76f948c..6c2bbbb 100644 --- a/main.py +++ b/main.py @@ -33,6 +33,15 @@ # strategy = tf.distribute.experimental.CentralStorageStrategy() # strategy = tf.distribute.OneDeviceStrategy(device='GPU:0') +from time import time +from vangan import VanGan, train +from custom_callback import GanMonitor +from dataset import DatasetGen +from preprocessing import DataPreprocessor +from tb_callback import TB_Summary +from utils import min_max_norm_tf, rescale_arr_tf, z_score_norm, threshold_outliers, save_args, min_max_norm, clahe_3d +from post_training import epoch_sweep +from scipy.ndimage import median_filter ''' TENSORFLOW DEBUGGING ''' # tf.config.set_soft_device_placement(True) @@ -166,10 +175,11 @@ def preprocess_rsom_images(img, lower_thresh=0.05, upper_thresh=99.95): # Define function to preprocess imaging domain image on the fly (otf) # Min/max batch normalisation and rescaling to [-1,1] shown here @tf.function -def process_imaging_otf(tensor): +def process_imaging_otf(tensor, axis=(1, 2, 3, 4), keepdims=True): + # Calculate the maximum and minimum values along the batch dimension - max_vals = tf.reduce_max(tensor, axis=(1, 2, 3, 4), keepdims=True) - min_vals = tf.reduce_min(tensor, axis=(1, 2, 3, 4), keepdims=True) + max_vals = tf.reduce_max(tensor, axis=axis, keepdims=keepdims) + min_vals = tf.reduce_min(tensor, axis=axis, keepdims=keepdims) # Normalize the tensor between -1 and 1 return 2.0 * (tensor - min_vals) / (max_vals - min_vals) - 1.0 @@ -205,6 +215,9 @@ def process_imaging_otf(tensor): process_imaging_domain=process_imaging_otf ) +# Save args to txt file +save_args(args, os.path.join(args.output_dir, 'Args_Settings.txt')) + ''' TRAIN VAN-GAN MODEL ''' for epoch in range(args.EPOCHS): print(f'\nEpoch {epoch + 1:03d}/{args.EPOCHS:03d}') diff --git a/utils.py b/utils.py index e9e4f40..37b316c 100644 --- a/utils.py +++ b/utils.py @@ -4,7 +4,10 @@ import tensorflow as tf import skimage.io as sk from skimage import 
exposure - +from scipy import stats +import tf_clahe +import mclahe as mc +import tensorflow_addons as tfa def min_max_norm(data): """ @@ -39,14 +42,12 @@ def min_max_norm_tf(arr, axis=None): # Normalize entire array min_val = tf.reduce_min(arr) max_val = tf.reduce_max(arr) - tensor = (arr - min_val) / (max_val - min_val) else: # Normalize along a specific axis min_val = tf.reduce_min(arr, axis=axis, keepdims=True) max_val = tf.reduce_max(arr, axis=axis, keepdims=True) - tensor = (arr - min_val) / (max_val - min_val) - return tensor + return (arr - min_val) / (max_val - min_val) def rescale_arr_tf(arr, alpha=-0.5, beta=0.5): @@ -84,6 +85,28 @@ def z_score_norm(data): # raise ValueError("Cannot perform z-score normalization when the standard deviation is zero.") +import tensorflow as tf + + +def z_score_norm_tf(data, epsilon=1e-8): + """ + Perform z-score normalization on a TensorFlow tensor. + + Args: + - data (tf.Tensor): A TensorFlow tensor containing the data to be normalized. + Shape should be (batch, depth, width, length, channel). + - epsilon (float): A small value to avoid division by zero when std_data is close to zero. + + Returns: + - tf.Tensor: A TensorFlow tensor containing the normalized data. + Shape will be the same as the input. + """ + mean_data = tf.math.reduce_mean(data, axis=(1, 2, 3, 4), keepdims=True) + std_data = tf.math.reduce_std(data, axis=(1, 2, 3, 4), keepdims=True) + + return (data - mean_data) / tf.where(std_data > epsilon, std_data, epsilon) + + def threshold_outliers(image_volume, threshold=6): """ Thresholds outlier voxels in the input 3D image volume. @@ -152,6 +175,7 @@ def binarise_tensor(arr): tf.ones(tf.shape(arr)), tf.math.negative(tf.ones(tf.shape(arr)))) + def add_gauss_noise(self, img, rate): """ Add Gaussian noise to a TensorFlow image tensor. @@ -165,6 +189,20 @@ def add_gauss_noise(self, img, rate): """ return tf.clip_by_value(img + tf.random.normal(tf.shape(img), 0.0, rate), -1., 1.) 
+ +def clip_images(images): + """ + Clips input images to the range of [-1, 1]. + + Args: + images: Input image batch tensor. + + Returns: + Clipped image batch tensor. + """ + return tf.clip_by_value(images, clip_value_min=-1.0, clip_value_max=1.0) + + def load_volume(file, size=(600, 600, 700), datatype='uint8', normalise=True): """ Load a volume from a (for example) tif file and normalise it. @@ -219,7 +257,7 @@ def resize_volume(img, target_size=None): return arr2 -def get_vaccuum(arr, dim): +def get_vacuum(arr, dim=3): """ Returns the smallest subarray containing all non-zero elements in the input array along the specified dimension(s). @@ -231,10 +269,10 @@ def get_vaccuum(arr, dim): numpy.ndarray: Subarray containing all non-zero elements in the input array along the specified dimension(s). """ if dim == 2: - x, y = np.nonzero(arr) + x, y, _ = np.nonzero(arr) return arr[x.min():x.max() + 1, y.min():y.max() + 1] else: - x, y, z = np.nonzero(arr) + x, y, z, _ = np.nonzero(arr) return arr[x.min():x.max() + 1, y.min():y.max() + 1, z.min():z.max() + 1] @@ -284,10 +322,9 @@ def append_dict(dict1, dict2, replace=False) -> dict: """ Append items in dict2 to dict1. 
- Args: - - dict1 (dict): The dictionary to which items in dict2 will be appended - - dict2 (dict): The dictionary containing items to be appended to dict1 - - replace (bool): If True, existing values in dict1 with the same key as values in dict2 will be replaced with the values from dict2 + Args: - dict1 (dict): The dictionary to which items in dict2 will be appended - dict2 (dict): The dictionary + containing items to be appended to dict1 - replace (bool): If True, existing values in dict1 with the same key as + values in dict2 will be replaced with the values from dict2 Returns: - dict: A dictionary containing the appended items @@ -357,8 +394,64 @@ def get_shape(arr): arr = arr[0] # set arr to the first element of arr return res # return the shape list -# import tf_clahe -# @tf.function(experimental_compile=True) # Enable XLA -# def fast_clahe(img): -# return tf_clahe.clahe(img, tile_grid_size=(4, 4), gpu_optimized=True) +@tf.function +def fast_clahe(img, gpu_optimized=True): + return tf_clahe.clahe(img, clip_limit=1.5, gpu_optimized=gpu_optimized) + +@tf.function +def clahe_3d(image): + """ + Applies 3D Contrast Limited Adaptive Histogram Equalization (CLAHE) to a 3D image. + + Args: + image (tf.Tensor): Input 3D image of shape (batch_size, width, length, depth, channels). + clip_limit (float): Clip limit for CLAHE. + grid_size (tuple): Size of the grid for histogram equalization (depth, width, length). + num_bins (int): Number of bins in the histogram. + + Returns: + tf.Tensor: Processed 3D image. 
+ """ + # Extract dimensions + batch_size, width, length, depth, channels = image.shape + + # Initialize a list to hold the processed slices + processed_slices = [] + + # Create a CLAHE op for each depth slice and append it to the list + for d in range(depth): + slice_image = image[:, :, :, d, :] + + # Apply CLAHE to the slice using fast_clahe function + # clahe = tfa.image.median_filter2d( + # fast_clahe(slice_image), + # filter_shape=(2, 2) + # ) + clahe = fast_clahe(slice_image) + + # Append the processed slice to the list + processed_slices.append(clahe) + + # Stack the processed slices to form the final 3D image + processed_image = tf.stack(processed_slices, axis=3) + + return processed_image + + +def save_args(args, filename): + def format_value(value): + if isinstance(value, tuple): + return f"({', '.join(map(str, value))})" + return str(value) + + # Filter out attributes that are not argparse arguments + arg_dict = {arg: value for arg, value in vars(args).items() if not arg.startswith('_')} + + with open(filename, "w") as f: + f.write("Command line arguments:\n") + for arg, value in arg_dict.items(): + formatted_value = format_value(value) + f.write(f"{arg}: {formatted_value}\n") + + diff --git a/vangan.py b/vangan.py index 8371e7e..3c7cc29 100644 --- a/vangan.py +++ b/vangan.py @@ -24,8 +24,11 @@ def __init__( strategy, lambda_cycle=10.0, lambda_identity=5, + lambda_reconstruction=5, + lambda_topology=5, gen_i2s='resnet', gen_s2i='resnet', + semi_supervised=False, wasserstein=False, ncritic=5, gp_weight=10.0 @@ -34,10 +37,13 @@ def __init__( self.n_devices = args.N_DEVICES self.img_size = args.INPUT_IMG_SIZE self.lambda_cycle = lambda_cycle - self.lambda_identity = lambda_identity, + self.lambda_identity = lambda_identity + self.lambda_reconstruction = lambda_reconstruction + self.lambda_topology = lambda_topology self.channels = args.CHANNELS self.gen_i2s_typ = gen_i2s self.gen_s2i_typ = gen_s2i + self.semi_supervised = semi_supervised 
self.global_batch_size = args.GLOBAL_BATCH_SIZE self.dims = args.DIMENSIONS if self.dims == 2: @@ -115,7 +121,7 @@ def __init__( # output_activation=None, ) else: - raise ValueError('AB Generator type not recognised') + raise ValueError('IS Generator type not recognised') if self.gen_s2i_typ == 'resnet': self.gen_SI = get_resnet_generator( @@ -155,7 +161,7 @@ def __init__( use_input_noise=False ) else: - raise ValueError('BA Generator type not recognised') + raise ValueError('SI Generator type not recognised') # Get the discriminators self.disc_I = get_discriminator( @@ -163,7 +169,8 @@ def __init__( batch_size=self.global_batch_size, name='discriminator_I', filters=64, - use_dropout=False, + use_dropout=True, + dropout_rate=0.2, wasserstein=self.wasserstein, use_SN=False, use_input_noise=True, @@ -175,7 +182,8 @@ def __init__( batch_size=self.global_batch_size, name='discriminator_S', filters=64, - use_dropout=False, + use_dropout=True, + dropout_rate=0.2, wasserstein=self.wasserstein, use_SN=False, use_input_noise=True, @@ -227,14 +235,14 @@ def __init__( clipnorm=100) # Initialise checkpoint - self.checkpoint = tf.train.Checkpoint(gen_AB=self.gen_IS, - gen_BAF=self.gen_SI, - disc_A=self.disc_I, - disc_B=self.disc_S, - gen_A_optimizer=self.gen_I_optimizer, - gen_B_optimizer=self.gen_S_optimizer, - disc_A_optimizer=self.disc_I_optimizer, - disc_B_optimizer=self.disc_S_optimizer) + self.checkpoint = tf.train.Checkpoint(gen_IS=self.gen_IS, + gen_SI=self.gen_SI, + disc_I=self.disc_I, + disc_S=self.disc_S, + gen_I_optimizer=self.gen_I_optimizer, + gen_S_optimizer=self.gen_S_optimizer, + disc_I_optimizer=self.disc_I_optimizer, + disc_S_optimizer=self.disc_S_optimizer) def save_checkpoint(self, epoch): """ save checkpoint to checkpoint_dir, overwrite if exists """ @@ -246,14 +254,16 @@ def load_checkpoint(self, epoch=None, expect_partial: bool = False, newpath=None if newpath is not None: self.checkpoint_prefix = os.path.join(newpath, 'checkpoint') checkpoint_path = 
self.checkpoint_prefix + "_e{epoch}".format(epoch=epoch) - if os.path.exists(f'{os.path.join(checkpoint_path)}.index'): - self.checkpoint_loaded = True - with self.strategy.scope(): - if expect_partial: - self.checkpoint.read(checkpoint_path).expect_partial() - else: - self.checkpoint.read(checkpoint_path) - print(f'\nLoaded checkpoint from {checkpoint_path}\n') + + print(f"Trying to load checkpoint from path: {checkpoint_path}") + checkpoint_files = [f'{checkpoint_path}.index', f'{checkpoint_path}.data-00000-of-00001'] + if all(os.path.exists(file) for file in checkpoint_files): + if expect_partial: + self.checkpoint.restore(checkpoint_path).expect_partial() + else: + self.checkpoint.restore(checkpoint_path) + print(f'Loaded checkpoint from {checkpoint_path}\n') + else: print('Error: Checkpoint not found!') @@ -293,9 +303,9 @@ def compute_losses(self, real_I, real_S, result, training=True): seg_loss = self.seg_loss_fn(self, real_S, cycled_S) cycled_I = self.gen_SI(fake_S, training=training) - cycle_loss_S = self.cycle_loss_fn(self, real_I, cycled_I, typ='L2') + cycle_loss_S = self.cycle_loss_fn(self, real_I, cycled_I, typ='mse') - reconstruction_loss_I = self.reconstruction_loss(self, real_I, cycled_I) + reconstruction_loss = self.reconstruction_loss(self, real_I, cycled_I) # Identity mapping # id_SI_loss = self.identity_loss_fn(self, real_I, self.gen_SI(real_I, training=True)) @@ -332,8 +342,8 @@ def compute_losses(self, real_I, real_S, result, training=True): 'D_S_loss': disc_S_loss, 'gen_IS_loss': gen_IS_loss, 'gen_SI_loss': gen_SI_loss, - 'cycle_gen_IS_loss': cycle_loss_I, - 'cycle_gen_SI_loss': cycle_loss_S, + 'cycle_gen_SIS_loss': cycle_loss_I, + 'cycle_gen_ISI_loss': cycle_loss_S, 'seg_loss': seg_loss, 'reconstruction_loss_I': reconstruction_loss_I, # 'identity_IS': id_IS_loss,