From 83ec8287ca7e2c38ac18699da5247406dcf65473 Mon Sep 17 00:00:00 2001
From: sweene01
Date: Fri, 21 Jul 2023 11:04:07 +0100
Subject: [PATCH] Housekeeping. Simplified minor sections of code to reduce
 duplication. Adjusted several variable and parameter names to improve
 interpretability. Removed redundant initializer assignments and made
 'he_normal' the default.
---
 building_blocks.py |  159 ++++++------
 clDice_func.py     |  135 +++++-----
 custom_callback.py |  352 +++++++++++++------------
 dataset.py         |  205 ++++++++-------
 discriminator.py   |   50 ++--
 generator.py       |   44 ++--
 loss_functions.py  |   18 +-
 main.py            |  176 ++++++-------
 post_training.py   |   20 +-
 preprocessing.py   |  146 +++++------
 resunet_model.py   |  163 +++++++-----
 tb_callback.py     |   67 ++---
 utils.py           |   98 ++++---
 vangan.py          |  635 ++++++++++++++++++++++-----------------
 vnet_model.py      |  146 +++++++----
 15 files changed, 1268 insertions(+), 1146 deletions(-)

diff --git a/building_blocks.py b/building_blocks.py
index 55d9d63..2182125 100644
--- a/building_blocks.py
+++ b/building_blocks.py
@@ -4,12 +4,13 @@
 import tensorflow_addons as tfa
 
-def npy_padding(x, padding=(1,1,1), padtype='reflect'):
-    return np.pad(x, ((padding[0],padding[0]),
-                      (padding[1],padding[1]),
-                      (padding[2],padding[2])),
+
+def npy_padding(x, padding=(1, 1, 1)):
+    return np.pad(x, ((padding[0], padding[0]),
+                      (padding[1], padding[1]),
+                      (padding[2], padding[2])),
                   'reflect')
-    
+
 
 class ReflectionPadding3D(layers.Layer):
     """Implements Reflection Padding as a layer.
@@ -26,7 +27,7 @@ def __init__(self, padding=(1, 1, 1), **kwargs):
         self.padding = tuple(padding)
         super(ReflectionPadding3D, self).__init__(**kwargs)
 
-    def call(self, input_tensor, mask=None):
+    def call(self, input_tensor):
         padding_width, padding_height, padding_depth = self.padding
         padding_tensor = [
             [0, 0],
@@ -36,7 +37,8 @@ def call(self, input_tensor, mask=None):
             [0, 0],
         ]
         return tf.pad(input_tensor, padding_tensor, mode="REFLECT")
-    
+
+
 class ReflectionPadding2D(layers.Layer):
     """Implements Reflection Padding as a layer.
 
@@ -52,7 +54,7 @@ def __init__(self, padding=(1, 1), **kwargs):
         self.padding = tuple(padding)
         super(ReflectionPadding2D, self).__init__(**kwargs)
 
-    def call(self, input_tensor, mask=None):
+    def call(self, input_tensor):
         padding_width, padding_height = self.padding
         padding_tensor = [
             [0, 0],
@@ -60,17 +62,18 @@ def call(self, input_tensor, mask=None):
             [padding_width, padding_width],
             [0, 0],
         ]
-        return tf.pad(input_tensor, padding_tensor, mode="REFLECT") 
+        return tf.pad(input_tensor, padding_tensor, mode="REFLECT")
+
+
 def residual_block(
-    x,
-    activation,
-    kernel_initializer=None,
-    kernel_size=(3, 3, 3),
-    strides=(1, 1, 1),
-    padding="valid",
-    gamma_initializer=None,
-    use_bias=False
+        x,
+        activation,
+        kernel_initializer=None,
+        kernel_size=(3, 3, 3),
+        strides=(1, 1, 1),
+        padding="valid",
+        gamma_initializer=None,
+        use_bias=False
 ):
     """
     Defines a residual block for use in a 3D convolutional neural network.
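    Residual blocks of this kind typically pair 'valid' convolutions with the
    ReflectionPadding3D layer defined above; a minimal, illustrative sketch of
    that layer's effect on spatial shape (example values, not from this patch):

        >>> import tensorflow as tf
        >>> x = tf.zeros((1, 4, 4, 4, 1))
        >>> ReflectionPadding3D(padding=(1, 1, 1))(x).shape
        TensorShape([1, 6, 6, 6, 1])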
@@ -119,21 +122,22 @@ def residual_block( x = layers.add([input_tensor, x]) return x + def downsample( - x, - filters, - activation, - kernel_initializer=None, - kernel_size=(3, 3, 3), - strides=(2, 2, 2), - padding="valid", - gamma_initializer=None, - use_bias=False, - use_dropout=True, - use_SN=False, - padding_size=(1, 1, 1), - use_layer_noise=False, - noise_std=0.1 + x, + filters, + activation, + kernel_initializer='he_normal', + kernel_size=(3, 3, 3), + strides=(2, 2, 2), + padding="valid", + gamma_initializer=None, + use_bias=False, + use_dropout=True, + use_spec_norm=False, + padding_size=(1, 1, 1), + use_layer_noise=False, + noise_std=0.1 ): """ Downsamples an input tensor using a 3D convolutional layer. @@ -149,7 +153,7 @@ def downsample( gamma_initializer (str, optional): Gamma initializer for InstanceNormalization. Defaults to None. use_bias (bool, optional): Whether to use bias in the convolutional layer. Defaults to False. use_dropout (bool, optional): Whether to use dropout after activation. Defaults to True. - use_SN (bool, optional): Whether to use Spectral Normalization. Defaults to False. + use_spec_norm (bool, optional): Whether to use Spectral Normalization. Defaults to False. padding_size (tuple of ints, optional): Padding size for ReflectionPadding3D. Defaults to (1, 1, 1). use_layer_noise (bool, optional): Whether to add Gaussian noise after ReflectionPadding3D. Defaults to False. noise_std (float, optional): Standard deviation of Gaussian noise. Defaults to 0.1. @@ -157,14 +161,14 @@ def downsample( Returns: Tensor: The downsampled tensor. """ - + if padding == 'valid': x = ReflectionPadding3D(padding_size)(x) - + if use_layer_noise: x = layers.GaussianNoise(noise_std)(x) - if use_SN: + if use_spec_norm: x = tfa.layers.SpectralNormalization(layers.Conv3D( filters, kernel_size, @@ -183,37 +187,37 @@ def downsample( use_bias=use_bias )(x) x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x) - + if activation: x = activation(x) if use_dropout: x = layers.SpatialDropout3D(0.2)(x) return x + def deconv( - x, - filters, - activation, - kernel_size=(4,4,4), - strides=(2, 2, 2), - padding="same", - kernel_initializer=None, - gamma_initializer=None, - use_bias=False, + x, + filters, + activation, + kernel_size=(4, 4, 4), + strides=(2, 2, 2), + padding="same", + kernel_initializer='he_normal', + gamma_initializer=None, + use_bias=False, ): """ 3D deconvolution on the input tensor `x` using transpose convolutional layers. - Args: - x (tf.Tensor): Input tensor of shape [batch_size, height, width, depth, channels] - filters (int): Number of output filters in the convolutional layer. - activation (Callable, optional): Activation function to use. If `None`, no activation is applied. - kernel_size (tuple, optional): Size of the 3D convolutional kernel. Defaults to (4, 4, 4). - strides (tuple, optional): The strides of the deconvolution. Defaults to (2, 2, 2). - padding (str, optional): The type of padding to apply. Defaults to 'same'. - kernel_initializer (tf.keras.initializers.Initializer, optional): Initializer for the kernel weights. Defaults to None. - gamma_initializer (tf.keras.initializers.Initializer, optional): Initializer for the gamma weights of instance normalization layer. Defaults to None. - use_bias (bool, optional): Whether to include a bias term in the convolutional layer. Defaults to False. 
+    Args:
+        x (tf.Tensor): Input tensor of shape [batch_size, height, width, depth, channels].
+        filters (int): Number of output filters in the convolutional layer.
+        activation (Callable, optional): Activation function to use. If `None`, no activation is applied.
+        kernel_size (tuple, optional): Size of the 3D convolutional kernel. Defaults to (4, 4, 4).
+        strides (tuple, optional): The strides of the deconvolution. Defaults to (2, 2, 2).
+        padding (str, optional): The type of padding to apply. Defaults to 'same'.
+        kernel_initializer (tf.keras.initializers.Initializer, optional): Initializer for the kernel weights.
+            Defaults to 'he_normal'.
+        gamma_initializer (tf.keras.initializers.Initializer, optional): Initializer for the gamma weights of the
+            instance normalization layer. Defaults to None.
+        use_bias (bool, optional): Whether to include a bias term in the convolutional layer. Defaults to False.
 
     Returns:
         tf.Tensor: Output tensor of shape [batch_size, height, width, depth, filters].
@@ -231,41 +235,40 @@ def deconv(
         x = activation(x)
     return x
 
+
 def upsample(
-    x,
-    filters,
-    activation,
-    kernel_size=(4,4,4),
-    strides=(2, 2, 2),
-    padding="same",
-    kernel_initializer=None,
-    gamma_initializer=None,
-    use_bias=False,
+        x,
+        filters,
+        activation,
+        kernel_size=(4, 4, 4),
+        strides=(1, 1, 1),
+        padding="same",
+        kernel_initializer='he_normal',
+        gamma_initializer=None,
+        use_bias=False,
 ):
     """
-    Upsamples the input tensor using 3D transposed convolution and applies instance normalization.
+    Upsamples the input tensor with UpSampling3D followed by a 3D convolution, and applies instance normalization.
 
-    Args:
-        x (tf.Tensor): The input tensor.
-        filters (int): The dimensionality of the output space.
-        activation (Optional[Callable]): The activation function to use. Defaults to None.
-        kernel_size (Tuple[int, int, int]): The size of the 3D transposed convolution window. Defaults to (4, 4, 4).
-        strides (Tuple[int, int, int]): The strides of the 3D transposed convolution. Defaults to (2, 2, 2).
-        padding (str): The type of padding to use. Defaults to 'same'.
-        kernel_initializer (Optional[Callable]): The initializer for the kernel weights. Defaults to None.
-        gamma_initializer (Optional[Callable]): The initializer for the gamma weights of the instance normalization layer. Defaults to None.
-        use_bias (bool): Whether to include a bias vector in the convolution layer. Defaults to False.
+    Args:
+        x (tf.Tensor): The input tensor.
+        filters (int): The dimensionality of the output space.
+        activation (Optional[Callable]): The activation function to use. Defaults to None.
+        kernel_size (Tuple[int, int, int]): The size of the 3D convolution window. Defaults to (4, 4, 4).
+        strides (Tuple[int, int, int]): The strides of the convolution applied after upsampling. Defaults to (1, 1, 1).
+        padding (str): The type of padding to use. Defaults to 'same'.
+        kernel_initializer (Optional[Callable]): The initializer for the kernel weights. Defaults to 'he_normal'.
+        gamma_initializer (Optional[Callable]): The initializer for the gamma weights of the instance normalization
+            layer. Defaults to None.
+        use_bias (bool): Whether to include a bias vector in the convolution layer. Defaults to False.
 
     Returns:
         tf.Tensor: The upsampled tensor with instance normalization applied.
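    Example (illustrative only; shapes assume the (1, 1, 1) stride default):

        >>> import tensorflow as tf
        >>> vol = tf.zeros((1, 8, 8, 8, 1))
        >>> upsample(vol, filters=16, activation=layers.Activation('relu')).shape
        TensorShape([1, 16, 16, 16, 16])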
""" x = layers.UpSampling3D( size=2 - )(x) + )(x) x = layers.Conv3D( filters, kernel_size, - strides=(1,1,1), + strides=strides, padding=padding, kernel_initializer=kernel_initializer, use_bias=use_bias, @@ -273,4 +276,4 @@ def upsample( x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x) if activation: x = activation(x) - return x \ No newline at end of file + return x diff --git a/clDice_func.py b/clDice_func.py index c9c9a80..20dad95 100644 --- a/clDice_func.py +++ b/clDice_func.py @@ -2,15 +2,18 @@ from keras import layers as KL from keras import backend as K -''' https://github.com/jocpae/clDice''' +''' Based on: https://github.com/jocpae/clDice''' def soft_erode(img): - """[This function performs soft-erosion operation on a float32 image] + """ + Perform soft erosion on a given image tensor. + Args: - img ([float32]): [image to be soft eroded] + img (tf.Tensor): Input image tensor on which soft erosion will be performed. + Returns: - [float32]: [the eroded image] + (tf.Tensor): Image tensor after performing soft erosion. """ if len(img.shape) == 4: p2 = -KL.MaxPool2D(pool_size=(3, 1), strides=(1, 1), padding='same', data_format=None)(-img) @@ -24,11 +27,14 @@ def soft_erode(img): def soft_dilate(img): - """[This function performs soft-dilation operation on a float32 image] + """ + Perform soft dilation on a given image tensor. + Args: - img ([float32]): [image to be soft dialated] + img (tf.Tensor): Input image tensor on which soft dilation will be performed. + Returns: - [float32]: [the dialated image] + (tf.Tensor): Image tensor after performing soft dilation. """ if len(img.shape) == 4: return KL.MaxPool2D(pool_size=(3, 3), strides=(1, 1), padding='same', data_format=None)(img) @@ -37,11 +43,14 @@ def soft_dilate(img): def soft_open(img): - """[This function performs soft-open operation on a float32 image] + """ + Perform soft opening on a given image tensor. + Args: - img ([float32]): [image to be soft opened] + img (tf.Tensor): Input image tensor on which soft opening will be performed. + Returns: - [float32]: [image after soft-open] + (tf.Tensor): Image tensor after performing soft opening. """ img = soft_erode(img) img = soft_dilate(img) @@ -49,82 +58,92 @@ def soft_open(img): def soft_skel(img, iters): - """[summary] + """ + Perform soft skeletonisation on a given image tensor. + Args: - img ([float32]): [description] - iters ([int]): [description] + img (tf.Tensor): Input image tensor on which skeletonisation will be performed. + iters (int): Number of iterations for skeletonisation. + Returns: - [float32]: [description] + (tf.Tensor): Skeletonised image tensor after performing soft skeletonisation. """ img1 = soft_open(img) - skel = tf.nn.relu(img-img1) + skel = tf.nn.relu(img - img1) for j in range(iters): img = soft_erode(img) img1 = soft_open(img) - delta = tf.nn.relu(img-img1) + delta = tf.nn.relu(img - img1) intersect = tf.math.multiply(skel, delta) - skel += tf.nn.relu(delta-intersect) + skel += tf.nn.relu(delta - intersect) return skel -def soft_clDice_loss(iter_ = 50): - """[function to compute dice loss] +def soft_clDice_loss(y_true, y_pred, iter_=50): + """ + Compute the soft centre-line (clDice) loss, which is a variant of the Dice loss used in segmentation tasks. + Args: - iter_ (int, optional): [skeletonization iteration]. Defaults to 50. + y_true (tf.Tensor): The ground truth segmentation mask tensor. + y_pred (tf.Tensor): The predicted segmentation mask tensor. + iter_ (int, optional): The number of iterations for skeletonization. 
Defaults to 50. + + Returns: + (tf.Tensor): The computed soft clDice loss. """ - def loss(y_true, y_pred): - """[function to compute dice loss] - Args: - y_true ([float32]): [ground truth image] - y_pred ([float32]): [predicted image] - Returns: - [float32]: [loss value] - """ - smooth = 1. - skel_pred = soft_skel(y_pred, iter_) - skel_true = soft_skel(y_true, iter_) - pres = (K.sum(tf.math.multiply(skel_pred, y_true))+smooth)/(K.sum(skel_pred)+smooth) - rec = (K.sum(tf.math.multiply(skel_true, y_pred))+smooth)/(K.sum(skel_true)+smooth) - cl_dice = 1.- 2.0*(pres*rec)/(pres+rec) - return cl_dice - return loss + smooth = 1. + skel_pred = soft_skel(y_pred, iter_) + skel_true = soft_skel(y_true, iter_) + pres = (K.sum(tf.math.multiply(skel_pred, y_true)) + smooth) / (K.sum(skel_pred) + smooth) + rec = (K.sum(tf.math.multiply(skel_true, y_pred)) + smooth) / (K.sum(skel_true) + smooth) + cl_dice = 1. - 2.0 * (pres * rec) / (pres + rec) + + return cl_dice def soft_dice(y_true, y_pred): - """[function to compute dice loss] + """ + Compute the soft Dice loss. + Args: - y_true ([float32]): [ground truth image] - y_pred ([float32]): [predicted image] + y_true (tf.Tensor): The ground truth segmentation mask tensor. + y_pred (tf.Tensor): The predicted segmentation mask tensor. + Returns: - [float32]: [loss value] + (tf.Tensor): The computed soft Dice loss. """ smooth = 1 intersection = K.sum((y_true * y_pred)) - coeff = (2. * intersection + smooth) / (K.sum(y_true) + K.sum(y_pred) + smooth) - return (1. - coeff) + coeff = (2. * intersection + smooth) / (K.sum(y_true) + K.sum(y_pred) + smooth) + return 1. - coeff -def soft_dice_cldice_loss(iters = 15, alpha=0.5): - """[function to compute dice+cldice loss] +def soft_dice_cldice_loss(iters=15, alpha=0.5): + """ + Compute the combined soft Dice and clDice loss, a variant of the Dice loss used in segmentation tasks. + Args: - iters (int, optional): [skeletonization iteration]. Defaults to 15. - alpha (float, optional): [weight for the cldice component]. Defaults to 0.5. + iters (int, optional): The number of iterations for skeletonisation. Defaults to 15. + alpha (float, optional): The weight for the clDice component. Defaults to 0.5. + + Returns: + (function): The loss function to be used in training. """ + def loss(y_true, y_pred): - """[summary] + """ + Compute the combined soft Dice and clDice loss for a single batch of data. + Args: - y_true ([float32]): [ground truth image] - y_pred ([float32]): [predicted image] + y_true (tf.Tensor): The ground truth segmentation mask tensor. + y_pred (tf.Tensor): The predicted segmentation mask tensor. + Returns: - [float32]: [loss value] + (tf.Tensor): The computed combined loss value. """ - smooth = 1. 
- skel_pred = soft_skel(y_pred, iters) - skel_true = soft_skel(y_true, iters) - pres = (K.sum(tf.math.multiply(skel_pred, y_true))+smooth)/(K.sum(skel_pred)+smooth) - rec = (K.sum(tf.math.multiply(skel_true, y_pred))+smooth)/(K.sum(skel_true)+smooth) - cl_dice = 1.- 2.0*(pres*rec)/(pres+rec) + cl_dice = soft_clDice_loss(y_true, y_pred, iters) dice = soft_dice(y_true, y_pred) - return (1.0-alpha)*dice+alpha*cl_dice - return loss \ No newline at end of file + return (1.0 - alpha) * dice + alpha * cl_dice + + return loss diff --git a/custom_callback.py b/custom_callback.py index 5de2e57..ad47683 100644 --- a/custom_callback.py +++ b/custom_callback.py @@ -8,47 +8,50 @@ from tensorflow.keras import layers from utils import min_max_norm -class GAN_Monitor(): + +class GanMonitor: """A callback to generate and save images after each epoch""" - def __init__(self, + + def __init__(self, args, dataset=None, - Alist=None, + Alist=None, Blist=None, process_imaging_domain=None): - + self.imgSize = args.INPUT_IMG_SIZE self.test_AB = dataset.valFullDatasetA self.test_BA = dataset.valFullDatasetB self.Alist = Alist - self.Blist = Blist + self.Blist = Blist self.process_imaging_domain = process_imaging_domain self.period = args.PERIOD_2D_CALLBACK, self.period3D = args.PERIOD_3D_CALLBACK, self.model_path = args.output_dir, self.dims = args.DIMENSIONS - + self.period = self.period[0] self.period3D = self.period3D[0] self.model_path = self.model_path[0] - - def saveModel(self, model, epoch): + + def save_model(self, model, epoch): """Save the trained model at the given epoch. Args: model (object): The VANGAN model object. epoch (int): The epoch number. """ - + # if epoch > 100: - model.gen_AB.save(os.path.join(self.model_path, "checkpoints/e{epoch}_genAB".format(epoch=epoch+1))) - model.gen_BA.save(os.path.join(self.model_path, "checkpoints/e{epoch}_genBA".format(epoch=epoch+1))) - model.disc_A.save(os.path.join(self.model_path, "checkpoints/e{epoch}_discA".format(epoch=epoch+1))) - model.disc_B.save(os.path.join(self.model_path, "checkpoints/e{epoch}_discB".format(epoch=epoch+1))) - - def stitch_subvolumes(self, gen, img, subvol_size, - epoch=-1, stride=(25,25,128), - name=None, output_path=None, complete=False, padFactor=0.25, border_removal=True, process_img=False): + model.gen_AB.save(os.path.join(self.model_path, "checkpoints/e{epoch}_genAB".format(epoch=epoch + 1))) + model.gen_BA.save(os.path.join(self.model_path, "checkpoints/e{epoch}_genBA".format(epoch=epoch + 1))) + model.disc_A.save(os.path.join(self.model_path, "checkpoints/e{epoch}_discA".format(epoch=epoch + 1))) + model.disc_B.save(os.path.join(self.model_path, "checkpoints/e{epoch}_discB".format(epoch=epoch + 1))) + + def stitch_subvolumes(self, gen, img, subvol_size, + epoch=-1, stride=(25, 25, 128), + name=None, output_path=None, complete=False, padFactor=0.25, border_removal=True, + process_img=False): """ Stitch together subvolumes to create a full volume prediction. 
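For reference, the refactored clDice interface above is consumed like any Keras loss factory: soft_dice_cldice_loss still returns a closure, while soft_clDice_loss now takes tensors directly. A minimal sketch with toy tensors (shapes and values here are assumptions, not from this patch):

import tensorflow as tf
from clDice_func import soft_dice_cldice_loss

# Toy binary volumes shaped (batch, H, W, D, 1); values in [0, 1].
y_true = tf.cast(tf.random.uniform((1, 32, 32, 32, 1)) > 0.5, tf.float32)
y_pred = tf.random.uniform((1, 32, 32, 32, 1))

loss_fn = soft_dice_cldice_loss(iters=5, alpha=0.5)  # returns a closure
combined = loss_fn(y_true, y_pred)  # (1 - alpha) * Dice + alpha * clDice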
@@ -79,10 +82,10 @@ def stitch_subvolumes(self, gen, img, subvol_size,
             stride = list(stride)
             stride[2] = 1
             stride = list(stride)
-        
+
         if complete:
-            xspacing = int(padFactor*img.shape[0])
-            yspacing = int(padFactor*img.shape[1])
+            xspacing = int(padFactor * img.shape[0])
+            yspacing = int(padFactor * img.shape[1])
             oimgshape = img.shape
             if stride[2] == 1:
                 if self.dims == 2:
@@ -95,48 +98,49 @@ def stitch_subvolumes(self, gen, img, subvol_size,
                 else:
                     img = np.pad(img, ((xspacing, xspacing),
                                        (yspacing, yspacing),
-                                       (0, 0), 
+                                       (0, 0),
                                        (0, 0)), 'symmetric')
             else:
-                zspacing = int(padFactor*img.shape[2])
+                zspacing = int(padFactor * img.shape[2])
                 img = np.pad(img, ((xspacing, xspacing),
                                    (yspacing, yspacing),
-                                   (zspacing, zspacing), 
+                                   (zspacing, zspacing),
                                    (0, 0)), 'symmetric')
-        
+
         if self.dims == 2:
             H, W, D, C = img.shape[0], img.shape[1], 1, img.shape[2]
         else:
             H, W, D, C = img.shape[0], img.shape[1], img.shape[2], img.shape[3]
 
         kH, kW, kD = subvol_size[1], subvol_size[2], subvol_size[3]
-        
+
         if not complete or not border_removal:
             pH, pW, pD = 0, 0, 0
         else:
-            pH, pW, pD = int(0.1*kH), int(0.1*kW), int(0.1*kD)
+            pH, pW, pD = int(0.1 * kH), int(0.1 * kW), int(0.1 * kD)
             if kD == D:
                 pD = 0
-        
+
         if self.dims == 2:
             pix_tracker = np.zeros([H, W, C], dtype='float32')
         else:
             pix_tracker = np.zeros([H, W, D, C], dtype='float32')
         pred = np.zeros(img.shape, dtype='float32')
-        
+
         sh, sw, sd = stride
-        
-        dim_out_h = int(np.floor( (H - kH) / sh + 1 ))
-        dim_out_w = int(np.floor( (W - kW) / sw + 1 ))
-        dim_out_d = int(np.floor( (D - kD) / sd + 1 ))
-        
+
+        dim_out_h = int(np.floor((H - kH) / sh + 1))
+        dim_out_w = int(np.floor((W - kW) / sw + 1))
+        dim_out_d = int(np.floor((D - kD) / sd + 1))
+
         if complete:
-            print('\tImage size (X,Y,Z,C): %i x %i x %i x %i' %(oimgshape[0],oimgshape[1],oimgshape[2],oimgshape[3]))
-            print('\tImage size w/ padding (X,Y,Z,C): %i x %i x %i x %i' %(H,W,D,C))
-            print('\tSampling patch size (X,Y,Z,C): %i x %i x %i x %i' %(kH, kW, kD, 1))
-            print('\tBorder artefact removal pixel width (X,Y,Z): (%i, %i, %i)' %(pH, pW, pD))
-            print('\tStride pixel length (X,Y,Z): (%i, %i, %i)' %(sh, sw, sd))
-            print('\tNo. of stiches (X x Y x Z): %i x %i x %i' %(dim_out_h, dim_out_w, dim_out_d))
-            
+            print(
+                '\tImage size (X,Y,Z,C): %i x %i x %i x %i' % (oimgshape[0], oimgshape[1], oimgshape[2], oimgshape[3]))
+            print('\tImage size w/ padding (X,Y,Z,C): %i x %i x %i x %i' % (H, W, D, C))
+            print('\tSampling patch size (X,Y,Z,C): %i x %i x %i x %i' % (kH, kW, kD, 1))
+            print('\tBorder artefact removal pixel width (X,Y,Z): (%i, %i, %i)' % (pH, pW, pD))
+            print('\tStride pixel length (X,Y,Z): (%i, %i, %i)' % (sh, sw, sd))
+            print('\tNo. of stitches (X x Y x Z): %i x %i x %i' % (dim_out_h, dim_out_w, dim_out_d))
+
         start_row = 0
         end_row = H
         for i in range(dim_out_h + 1):
@@ -146,7 +150,7 @@ def stitch_subvolumes(self, gen, img, subvol_size,
                 start_row = H - kH
             if end_row < kH:
                 end_row = kH
-            
+
             for j in range(dim_out_w + 1):
                 start_dep = 0
                 end_dep = D
@@ -154,50 +158,51 @@ def stitch_subvolumes(self, gen, img, subvol_size,
                     start_col = W - kW
                 if end_col < kW:
                     end_col = kW
-                
+
                 for k in range(dim_out_d + 1):
                     if start_dep > D - kD:
                         start_dep = D - kD
                     if end_dep < kD:
                         end_dep = kD
-                    
+
                     # From one corner
-                    pix_tracker[start_row+pH:(start_row+kH-pH), start_col+pW:(start_col+kW-pW), start_dep+pD:(start_dep+kD-pD)] += 1.
- arr = img[start_row:(start_row+kH), - start_col:(start_col+kW), - start_dep:(start_dep+kD)] - + pix_tracker[start_row + pH:(start_row + kH - pH), start_col + pW:(start_col + kW - pW), + start_dep + pD:(start_dep + kD - pD)] += 1. + arr = img[start_row:(start_row + kH), + start_col:(start_col + kW), + start_dep:(start_dep + kD)] + if process_img == True and self.process_imaging_domain is not None: arr = self.process_imaging_domain(arr) - - arr = gen(np.expand_dims(arr, + + arr = gen(np.expand_dims(arr, axis=0), training=False)[0] - - arr = arr[pH:kH-pH, - pW:kW-pW, - pD:kD-pD] - - pred[start_row+pH:(start_row+kH-pH), - start_col+pW:(start_col+kW-pW), - start_dep+pD:(start_dep+kD-pD)] += arr - - + + arr = arr[pH:kH - pH, + pW:kW - pW, + pD:kD - pD] + + pred[start_row + pH:(start_row + kH - pH), + start_col + pW:(start_col + kW - pW), + start_dep + pD:(start_dep + kD - pD)] += arr + start_dep += sd end_dep -= sd start_col += sw end_col -= sw - start_row += sh + start_row += sh end_row -= sh pred = np.true_divide(pred, pix_tracker) # pred = np.nan_to_num(pred, nan=-1.) - + if complete: if stride[2] == 1: - pred = pred[xspacing:oimgshape[0]+xspacing,yspacing:oimgshape[1]+yspacing,] + pred = pred[xspacing:oimgshape[0] + xspacing, yspacing:oimgshape[1] + yspacing, ] else: - pred = pred[xspacing:oimgshape[0]+xspacing,yspacing:oimgshape[1]+yspacing,zspacing:oimgshape[2]+zspacing,] - + pred = pred[xspacing:oimgshape[0] + xspacing, yspacing:oimgshape[1] + yspacing, + zspacing:oimgshape[2] + zspacing, ] + pred = 255 * min_max_norm(pred) if not complete: @@ -206,20 +211,21 @@ def stitch_subvolumes(self, gen, img, subvol_size, if not complete: if self.dims == 2: pred = np.squeeze(pred) - io.imsave(os.path.join(self.model_path,"e{epoch}_{name}.tiff".format(epoch=epoch+1, name=name)), pred) + io.imsave(os.path.join(self.model_path, "e{epoch}_{name}.tiff".format(epoch=epoch + 1, name=name)), + pred) else: - io.imsave(os.path.join(self.model_path,"e{epoch}_{name}.tiff".format(epoch=epoch+1, name=name)), - np.transpose(pred,(2,0,1,3)), - bigtiff=False, check_contrast=False) + io.imsave(os.path.join(self.model_path, "e{epoch}_{name}.tiff".format(epoch=epoch + 1, name=name)), + np.transpose(pred, (2, 0, 1, 3)), + bigtiff=False, check_contrast=False) else: if self.dims == 2: pred = np.squeeze(pred) - io.imsave(os.path.join(output_path,"{name}.tiff".format(name=name)), pred) + io.imsave(os.path.join(output_path, "{name}.tiff".format(name=name)), pred) else: - io.imsave(os.path.join(output_path,"{name}.tiff".format(name=name)), - np.transpose(pred,(2,0,1,3)), - bigtiff=False, check_contrast=False) - + io.imsave(os.path.join(output_path, "{name}.tiff".format(name=name)), + np.transpose(pred, (2, 0, 1, 3)), + bigtiff=False, check_contrast=False) + def imagePlotter(self, epoch, filename, setlist, dataset, genX, genY, nfig=6, outputFull=True, process_img=False): """ Plot and save 2D sample images during training. 
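The stitching above amounts to overlap-averaging: each subvolume prediction is accumulated into pred while pix_tracker counts how often every voxel was covered, and the final quotient recovers the mean. A minimal 1-D sketch of the same arithmetic (toy sizes and an identity "generator"; all names here are hypothetical):

import numpy as np

H, k, s = 12, 4, 2                                    # volume length, patch length, stride
signal = np.arange(H, dtype=float)
pred, hits = np.zeros(H), np.zeros(H)
n_windows = int(np.floor((H - k) / s + 1))            # same count formula as dim_out_* above
for i in range(n_windows):
    start = i * s
    pred[start:start + k] += signal[start:start + k]  # identity stands in for gen(...)
    hits[start:start + k] += 1.0
stitched = np.true_divide(pred, hits)                 # mirrors np.true_divide(pred, pix_tracker)
assert np.allclose(stitched, signal)                  # every voxel covered at least once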
@@ -249,11 +255,13 @@ def imagePlotter(self, epoch, filename, setlist, dataset, genX, genY, nfig=6, ou # Generate random crop of sample if self.dims == 2: - sample = tf.expand_dims(tf.image.random_crop(sample, size=(self.imgSize[1], self.imgSize[2], self.imgSize[3])), - axis=0) + sample = tf.expand_dims( + tf.image.random_crop(sample, size=(self.imgSize[1], self.imgSize[2], self.imgSize[3])), + axis=0) else: - sample = tf.expand_dims(tf.image.random_crop(sample, size=(self.imgSize[1], self.imgSize[2], self.imgSize[3], self.imgSize[4])), - axis=0) + sample = tf.expand_dims( + tf.image.random_crop(sample, size=(self.imgSize[1], self.imgSize[2], self.imgSize[3], self.imgSize[4])), + axis=0) if process_img == True and self.process_imaging_domain is not None: sample = self.process_imaging_domain(sample) @@ -266,8 +274,8 @@ def imagePlotter(self, epoch, filename, setlist, dataset, genX, genY, nfig=6, ou prediction = prediction[0].numpy() cycled = cycled[0].numpy() identity = identity[0].numpy() - - _, ax = plt.subplots(nfig+1, 4, figsize=(12, 12)) + + _, ax = plt.subplots(nfig + 1, 4, figsize=(12, 12)) if self.dims == 2: nfig = 1 ax[0, 0].imshow(sample, cmap='gray') @@ -285,10 +293,10 @@ def imagePlotter(self, epoch, filename, setlist, dataset, genX, genY, nfig=6, ou else: for j in range(nfig): - ax[j, 0].imshow(sample[:,:,j*int(sample.shape[2]/nfig),0], cmap='gray') - ax[j, 1].imshow(prediction[:,:,j*int(sample.shape[2]/nfig),0], cmap='gray') - ax[j, 2].imshow(cycled[:,:,j*int(sample.shape[2]/nfig),0], cmap='gray') - ax[j, 3].imshow(identity[:,:,j*int(sample.shape[2]/nfig),0], cmap='gray') + ax[j, 0].imshow(sample[:, :, j * int(sample.shape[2] / nfig), 0], cmap='gray') + ax[j, 1].imshow(prediction[:, :, j * int(sample.shape[2] / nfig), 0], cmap='gray') + ax[j, 2].imshow(cycled[:, :, j * int(sample.shape[2] / nfig), 0], cmap='gray') + ax[j, 3].imshow(identity[:, :, j * int(sample.shape[2] / nfig), 0], cmap='gray') ax[j, 0].set_title("Input image") ax[j, 1].set_title("Translated image") ax[j, 2].set_title("Cycled image") @@ -297,24 +305,28 @@ def imagePlotter(self, epoch, filename, setlist, dataset, genX, genY, nfig=6, ou ax[j, 1].axis("off") ax[j, 2].axis("off") ax[j, 3].axis("off") - ax[nfig,0].hist(sample.ravel(), bins=256, range=(np.amin(sample),np.amax(sample)), fc='k', ec='k', density=True) - ax[nfig,1].hist(prediction.ravel(), bins=256, range=(np.amin(prediction),np.amax(prediction)), fc='k', ec='k', density=True) - ax[nfig,2].hist(cycled.ravel(), bins=256, range=(np.amin(cycled),np.amax(cycled)), fc='k', ec='k', density=True) - ax[nfig,3].hist(identity.ravel(), bins=256, range=(np.amin(identity),np.amax(identity)), fc='k', ec='k', density=True) - - plt.savefig("./GANMonitor/{epoch}_{genID}.png".format(epoch=epoch+1, - genID=filename), + ax[nfig, 0].hist(sample.ravel(), bins=256, range=(np.amin(sample), np.amax(sample)), fc='k', ec='k', + density=True) + ax[nfig, 1].hist(prediction.ravel(), bins=256, range=(np.amin(prediction), np.amax(prediction)), fc='k', ec='k', + density=True) + ax[nfig, 2].hist(cycled.ravel(), bins=256, range=(np.amin(cycled), np.amax(cycled)), fc='k', ec='k', + density=True) + ax[nfig, 3].hist(identity.ravel(), bins=256, range=(np.amin(identity), np.amax(identity)), fc='k', ec='k', + density=True) + + plt.savefig("./GANMonitor/{epoch}_{genID}.png".format(epoch=epoch + 1, + genID=filename), dpi=300) - + plt.tight_layout() plt.show(block=False) plt.close() - + # Generate 3D predictions, stitch and save - if epoch % self.period3D == 1 and outputFull:# and epoch > 160: - 
self.stitch_subvolumes(genX, storeSample.numpy(), - self.imgSize, epoch=epoch, name=sampleName, process_img=process_img) - + if epoch % self.period3D == 1 and outputFull: # and epoch > 160: + self.stitch_subvolumes(genX, storeSample.numpy(), + self.imgSize, epoch=epoch, name=sampleName, process_img=process_img) + def set_learning_rate(self, model, epoch, args): """ Sets the learning rate for each optimizer based on the current epoch. @@ -330,59 +342,64 @@ def set_learning_rate(self, model, epoch, args): Returns: None """ - + if epoch == args.INITIATE_LR_DECAY: - - model.gen_A_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=args.INITIAL_LR, - decay_steps=(args.EPOCHS-args.INITIATE_LR_DECAY)*args.train_steps, - end_learning_rate=0, - power=1) - - model.gen_B_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=args.INITIAL_LR, - decay_steps=(args.EPOCHS-args.INITIATE_LR_DECAY)*args.train_steps, - end_learning_rate=0, - power=1) - - model.disc_A_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=args.INITIAL_LR, - decay_steps=(args.EPOCHS-args.INITIATE_LR_DECAY)*args.train_steps, - end_learning_rate=0, - power=1) - - model.disc_B_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=args.INITIAL_LR, - decay_steps=(args.EPOCHS-args.INITIATE_LR_DECAY)*args.train_steps, - end_learning_rate=0, - power=1) - + model.gen_I_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay( + initial_learning_rate=args.INITIAL_LR, + decay_steps=(args.EPOCHS - args.INITIATE_LR_DECAY) * args.train_steps, + end_learning_rate=0, + power=1) + + model.gen_S_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay( + initial_learning_rate=args.INITIAL_LR, + decay_steps=(args.EPOCHS - args.INITIATE_LR_DECAY) * args.train_steps, + end_learning_rate=0, + power=1) + + model.disc_I_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay( + initial_learning_rate=args.INITIAL_LR, + decay_steps=(args.EPOCHS - args.INITIATE_LR_DECAY) * args.train_steps, + end_learning_rate=0, + power=1) + + model.disc_S_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay( + initial_learning_rate=args.INITIAL_LR, + decay_steps=(args.EPOCHS - args.INITIATE_LR_DECAY) * args.train_steps, + end_learning_rate=0, + power=1) + if model.checkpoint_loaded and epoch > args.INITIATE_LR_DECAY: - model.checkpoint_loaded = False - + learning_gradient = args.INITIAL_LR / (args.EPOCHS - args.INITIATE_LR_DECAY) intermediate_learning_rate = learning_gradient * (args.EPOCHS - epoch) - - print('Initial learning rate: %0.8f' %intermediate_learning_rate) - - model.gen_A_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=intermediate_learning_rate, - decay_steps=(args.EPOCHS-args.INITIATE_LR_DECAY-epoch)*args.train_steps, - end_learning_rate=0, - power=1) - - model.gen_B_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=intermediate_learning_rate, - decay_steps=(args.EPOCHS-args.INITIATE_LR_DECAY-epoch)*args.train_steps, - end_learning_rate=0, - power=1) - - model.disc_A_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=intermediate_learning_rate, - decay_steps=(args.EPOCHS-args.INITIATE_LR_DECAY-epoch)*args.train_steps, - end_learning_rate=0, - power=1) - - model.disc_B_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=intermediate_learning_rate, - 
decay_steps=(args.EPOCHS-args.INITIATE_LR_DECAY-epoch)*args.train_steps, - end_learning_rate=0, - power=1) - - + + print('Initial learning rate: %0.8f' % intermediate_learning_rate) + + model.gen_I_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay( + initial_learning_rate=intermediate_learning_rate, + decay_steps=(args.EPOCHS - args.INITIATE_LR_DECAY - epoch) * args.train_steps, + end_learning_rate=0, + power=1) + + model.gen_S_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay( + initial_learning_rate=intermediate_learning_rate, + decay_steps=(args.EPOCHS - args.INITIATE_LR_DECAY - epoch) * args.train_steps, + end_learning_rate=0, + power=1) + + model.disc_I_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay( + initial_learning_rate=intermediate_learning_rate, + decay_steps=(args.EPOCHS - args.INITIATE_LR_DECAY - epoch) * args.train_steps, + end_learning_rate=0, + power=1) + + model.disc_S_optimizer.lr = tf.keras.optimizers.schedules.PolynomialDecay( + initial_learning_rate=intermediate_learning_rate, + decay_steps=(args.EPOCHS - args.INITIATE_LR_DECAY - epoch) * args.train_steps, + end_learning_rate=0, + power=1) + def updateDiscriminatorNoise(self, model, init_noise, epoch, args): """ Update the standard deviation of the Gaussian noise layer in a VANGAN discriminator. @@ -403,14 +420,14 @@ def updateDiscriminatorNoise(self, model, init_noise, epoch, args): decay_rate = epoch / args.NO_NOISE noise = init_noise * (1. - decay_rate) # noise = 0.9 ** (epoch + 1) - print('Noise std: %0.5f' %noise) + print('Noise std: %0.5f' % noise) for layer in model.layers: if type(layer) == layers.GaussianNoise: if noise > 0.: layer.stddev = noise else: - layer.stddev = 0.0 - + layer.stddev = 0.0 + def on_epoch_start(self, model, epoch, args, logs=None): """ Callback function that is called at the start of each training epoch. @@ -425,12 +442,11 @@ def on_epoch_start(self, model, epoch, args, logs=None): None """ - + self.set_learning_rate(model, epoch, args) - - self.updateDiscriminatorNoise(model.disc_A, model.layer_noise, epoch, args) - self.updateDiscriminatorNoise(model.disc_B, model.layer_noise, epoch, args) - + + self.updateDiscriminatorNoise(model.disc_I, model.layer_noise, epoch, args) + self.updateDiscriminatorNoise(model.disc_S, model.layer_noise, epoch, args) def on_epoch_end(self, model, epoch, logs=None): """ @@ -447,10 +463,11 @@ def on_epoch_end(self, model, epoch, logs=None): """ # Generate 2D plots - self.imagePlotter(epoch, "genAB", self.Alist, self.test_AB, model.gen_AB, model.gen_BA, process_img=True) - self.imagePlotter(epoch, "genBA", self.Blist, self.test_BA, model.gen_BA, model.gen_AB, outputFull=True) - - def run_mapping(self, model, test_set, sub_img_size=(64,64,512,1), segmentation=True, stride=(25,25,1), padFactor=0.25, filetext=None, filepath=''): + self.imagePlotter(epoch, "genAB", self.Alist, self.test_AB, model.gen_IS, model.gen_SI, process_img=True) + self.imagePlotter(epoch, "genBA", self.Blist, self.test_BA, model.gen_SI, model.gen_IS, outputFull=True) + + def run_mapping(self, model, test_set, sub_img_size=(64, 64, 512, 1), segmentation=True, stride=(25, 25, 1), + padFactor=0.25, filetext=None, filepath=''): """ Runs mapping on a set of test images using the specified generator model and sub-volume size. 
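The resume branch above restarts the linear decay from a rate interpolated from how far through the decay window the reloaded epoch sits. A compact sketch of the idea with hypothetical numbers (decay_steps here is simply the steps remaining, a simplification of the bookkeeping above):

import tensorflow as tf

INITIAL_LR, EPOCHS, INITIATE_LR_DECAY, train_steps, epoch = 2e-4, 200, 100, 50, 150
gradient = INITIAL_LR / (EPOCHS - INITIATE_LR_DECAY)
intermediate_lr = gradient * (EPOCHS - epoch)     # 1e-4 halfway through the decay window
schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=intermediate_lr,
    decay_steps=(EPOCHS - epoch) * train_steps,   # steps remaining after the resume
    end_learning_rate=0,
    power=1)                                      # power=1 gives a straight line to zero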
@@ -468,31 +485,28 @@ def run_mapping(self, model, test_set, sub_img_size=(64,64,512,1), segmentation= None """ - + # num_cores = int(0.8*(multiprocessing.cpu_count() - 1)) # print('Processing training data ...') # Parallel(n_jobs=num_cores, verbose=50)(delayed( - # self.stitch_subvolumes)(gen=model.gen_AB, + # self.stitch_subvolumes)(gen=model.gen_IS, # img=np.load(test_set[imgdir]), # subvol_size=sub_img_size, # name=filetext+os.path.splitext(os.path.split(os.path.basename(test_set[imgdir]))[1])[0], # complete=True) for imgdir in range(len(test_set))) - + for imgdir in range(len(test_set)): # Extract test array and filename img = np.load(test_set[imgdir]) filename = os.path.basename(test_set[imgdir]) filename = os.path.splitext(os.path.split(filename)[1])[0] if segmentation: - print('Segmenting %s ... (%i / %i)' %(filename, imgdir+1, len(test_set))) + print('Segmenting %s ... (%i / %i)' % (filename, imgdir + 1, len(test_set))) # Generate segmentations, stitch and save - self.stitch_subvolumes(model.gen_AB, img, sub_img_size, name=filetext+filename, output_path=filepath, - complete=True, stride=stride, padFactor=padFactor) + self.stitch_subvolumes(model.gen_IS, img, sub_img_size, name=filetext + filename, output_path=filepath, + complete=True, stride=stride, padFactor=padFactor) else: - print('Mapping %s ... (%i / %i)' %(filename, imgdir+1, len(test_set))) + print('Mapping %s ... (%i / %i)' % (filename, imgdir + 1, len(test_set))) # Generate segmentations, stitch and save - self.stitch_subvolumes(model.gen_BA, img, sub_img_size, name=filetext+filename, output_path=filepath, - complete=True, process_img=True, stride=stride, padFactor=padFactor) - - - + self.stitch_subvolumes(model.gen_SI, img, sub_img_size, name=filetext + filename, output_path=filepath, + complete=True, process_img=True, stride=stride, padFactor=padFactor) diff --git a/dataset.py b/dataset.py index 046c464..bb8190f 100644 --- a/dataset.py +++ b/dataset.py @@ -5,12 +5,13 @@ import matplotlib.pyplot as plt from skimage import io -class dataset_gen: + +class DatasetGen: def __init__(self, args, imaging_domain_data, seg_domain_data, strategy: tf.distribute.Strategy, otf_imaging=None): ''' Setting shard policy for distributed dataset ''' options = tf.data.Options() options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA - + ''' Setting parameters for below ''' if args.DIMENSIONS == 2: self.imaging_output_shapes = (None, None, args.CHANNELS) @@ -20,9 +21,11 @@ def __init__(self, args, imaging_domain_data, seg_domain_data, strategy: tf.dist else: self.imaging_output_shapes = (None, None, None, args.CHANNELS) self.segmentation_output_shapes = (None, None, None, 1) - self.imaging_patch_shape = (args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], args.SUBVOL_PATCH_SIZE[2], args.CHANNELS) - self.segmentation_patch_shape = (args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], args.SUBVOL_PATCH_SIZE[2], 1) - + self.imaging_patch_shape = ( + args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], args.SUBVOL_PATCH_SIZE[2], args.CHANNELS) + self.segmentation_patch_shape = ( + args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], args.SUBVOL_PATCH_SIZE[2], 1) + self.strategy = strategy self.pathA = imaging_domain_data self.pathB = seg_domain_data @@ -30,53 +33,60 @@ def __init__(self, args, imaging_domain_data, seg_domain_data, strategy: tf.dist self.otf_imaging = otf_imaging self.IMAGE_THRESH = 0.5 self.SEG_THRESH = 0.8 - + ''' Create datasets ''' with self.strategy.scope(): - self.trainDatasetA = 
tf.data.Dataset.from_generator(lambda: self.datagenA('training'), - output_types=tf.float32, - output_shapes=self.imaging_output_shapes) - self.trainDatasetA = self.trainDatasetA.map(self.processImagingDomain, num_parallel_calls=tf.data.AUTOTUNE) - self.trainDatasetA = self.trainDatasetA.repeat() - self.trainDatasetA = self.trainDatasetA.with_options(options) - - self.trainDatasetB = tf.data.Dataset.from_generator(lambda: self.datagenB('training'), - output_types=tf.float32, - output_shapes=self.segmentation_output_shapes) - self.trainDatasetB = self.trainDatasetB.map(self.processSegDomain, num_parallel_calls=tf.data.AUTOTUNE) - self.trainDatasetB = self.trainDatasetB.repeat() - self.trainDatasetB = self.trainDatasetB.with_options(options) - - self.valDatasetA = tf.data.Dataset.from_generator(lambda: self.datagenA('validation'), - output_types=tf.float32, - output_shapes=self.imaging_output_shapes) - self.valDatasetA = self.valDatasetA.map(map_func = self.processImagingDomain, num_parallel_calls=tf.data.AUTOTUNE) - self.valDatasetA = self.valDatasetA.repeat() - self.valDatasetA = self.valDatasetA.with_options(options) - - self.valDatasetB = tf.data.Dataset.from_generator(lambda: self.datagenB('validation'), - output_types=tf.float32, - output_shapes=self.segmentation_output_shapes) - self.valDatasetB = self.valDatasetB.map(map_func = self.processSegDomain, num_parallel_calls=tf.data.AUTOTUNE) - self.valDatasetB = self.valDatasetB.repeat() - self.valDatasetB = self.valDatasetB.with_options(options) - + self.imaging_train_dataset = tf.data.Dataset.from_generator(lambda: self.datagenA('training'), + output_types=tf.float32, + output_shapes=self.imaging_output_shapes) + self.imaging_train_dataset = self.imaging_train_dataset.map(self.processImagingDomain, + num_parallel_calls=tf.data.AUTOTUNE) + self.imaging_train_dataset = self.imaging_train_dataset.repeat() + self.imaging_train_dataset = self.imaging_train_dataset.with_options(options) + + self.segmentation_train_dataset = tf.data.Dataset.from_generator(lambda: self.datagenB('training'), + output_types=tf.float32, + output_shapes=self.segmentation_output_shapes) + self.segmentation_train_dataset = self.segmentation_train_dataset.map(self.processSegDomain, + num_parallel_calls=tf.data.AUTOTUNE) + self.segmentation_train_dataset = self.segmentation_train_dataset.repeat() + self.segmentation_train_dataset = self.segmentation_train_dataset.with_options(options) + + self.imaging_val_dataset = tf.data.Dataset.from_generator(lambda: self.datagenA('validation'), + output_types=tf.float32, + output_shapes=self.imaging_output_shapes) + self.imaging_val_dataset = self.imaging_val_dataset.map(map_func=self.processImagingDomain, + num_parallel_calls=tf.data.AUTOTUNE) + self.imaging_val_dataset = self.imaging_val_dataset.repeat() + self.imaging_val_dataset = self.imaging_val_dataset.with_options(options) + + self.segmentation_val_dataset = tf.data.Dataset.from_generator(lambda: self.datagenB('validation'), + output_types=tf.float32, + output_shapes=self.segmentation_output_shapes) + self.segmentation_val_dataset = self.segmentation_val_dataset.map(map_func=self.processSegDomain, + num_parallel_calls=tf.data.AUTOTUNE) + self.segmentation_val_dataset = self.segmentation_val_dataset.repeat() + self.segmentation_val_dataset = self.segmentation_val_dataset.with_options(options) + self.plotSampleDataset() - + self.valFullDatasetA = tf.data.Dataset.from_generator(self.valDatagenA, output_types=(tf.float32, tf.int8)) self.valFullDatasetB = 
tf.data.Dataset.from_generator(self.valDatagenB, output_types=(tf.float32, tf.int8)) - - self.train_ds = tf.data.Dataset.zip((self.trainDatasetA.batch(args.GLOBAL_BATCH_SIZE, drop_remainder=True), - self.trainDatasetB.batch(args.GLOBAL_BATCH_SIZE, drop_remainder=True))).prefetch(tf.data.AUTOTUNE) - self.val_ds = tf.data.Dataset.zip((self.valDatasetA.batch(args.GLOBAL_BATCH_SIZE, drop_remainder=True), - self.valDatasetB.batch(args.GLOBAL_BATCH_SIZE, drop_remainder=True))).prefetch(tf.data.AUTOTUNE) - self.train_ds = self.strategy.experimental_distribute_dataset(self.train_ds) - self.val_ds = self.strategy.experimental_distribute_dataset(self.val_ds) - + self.train_dataset = tf.data.Dataset.zip( + (self.imaging_train_dataset.batch(args.GLOBAL_BATCH_SIZE, drop_remainder=True), + self.segmentation_train_dataset.batch(args.GLOBAL_BATCH_SIZE, drop_remainder=True))).prefetch( + tf.data.AUTOTUNE) + self.val_dataset = tf.data.Dataset.zip( + (self.imaging_val_dataset.batch(args.GLOBAL_BATCH_SIZE, drop_remainder=True), + self.segmentation_val_dataset.batch(args.GLOBAL_BATCH_SIZE, drop_remainder=True))).prefetch( + tf.data.AUTOTUNE) + + self.train_dataset = self.strategy.experimental_distribute_dataset(self.train_dataset) + self.val_dataset = self.strategy.experimental_distribute_dataset(self.val_dataset) - ''' Functions to gather imaging subvolumes ''' + def datagenA(self, typ='training'): """ Generates a batch of data from the pathA directory. @@ -94,16 +104,16 @@ def datagenA(self, typ='training'): if iterA >= math.floor(len(datasetA) // self.args.GLOBAL_BATCH_SIZE): iterA = 0 np.random.shuffle(datasetA) - - file = datasetA[iterA*self.args.GLOBAL_BATCH_SIZE:(iterA+1)*self.args.GLOBAL_BATCH_SIZE] - + + file = datasetA[iterA * self.args.GLOBAL_BATCH_SIZE:(iterA + 1) * self.args.GLOBAL_BATCH_SIZE] + # Load batch of full size images for idx, filename in enumerate(file): - yield tf.convert_to_tensor(np.rot90(np.load(filename), + yield tf.convert_to_tensor(np.rot90(np.load(filename), np.random.choice([-1, 0, 1])), dtype=tf.float32) - + iterA += 1 - + def datagenB(self, typ='training'): """ Generates a batch of data from the pathB directory. 
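The generator methods above all feed tf.data.Dataset.from_generator pipelines. A minimal sketch of that pattern in isolation (toy data and hypothetical shapes, not the class's real file loading):

import numpy as np
import tensorflow as tf

def toy_gen():
    # Endless stream of full-size volumes, as datagenA/datagenB yield.
    while True:
        yield tf.convert_to_tensor(np.random.rand(64, 64, 64, 1), dtype=tf.float32)

ds = tf.data.Dataset.from_generator(toy_gen,
                                    output_types=tf.float32,
                                    output_shapes=(None, None, None, 1))
ds = ds.map(lambda img: tf.image.random_crop(img, size=(32, 32, 32, 1)),
            num_parallel_calls=tf.data.AUTOTUNE)   # random subvolume per sample
ds = ds.batch(2).prefetch(tf.data.AUTOTUNE)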
@@ -121,122 +131,121 @@ def datagenB(self, typ='training'): if iterB >= math.floor(len(datasetB) // self.args.GLOBAL_BATCH_SIZE): iterB = 0 np.random.shuffle(datasetB) - - file = datasetB[iterB*self.args.GLOBAL_BATCH_SIZE:(iterB+1)*self.args.GLOBAL_BATCH_SIZE] - + + file = datasetB[iterB * self.args.GLOBAL_BATCH_SIZE:(iterB + 1) * self.args.GLOBAL_BATCH_SIZE] + # Load batch of full size images for idx, filename in enumerate(file): - yield tf.convert_to_tensor(np.rot90(np.load(filename), + yield tf.convert_to_tensor(np.rot90(np.load(filename), np.random.choice([-1, 0, 1])), dtype=tf.float32) - + iterB += 1 - + def valDatagenA(self): while True: - i = random.randint(0,self.pathA['validation'].shape[0]-1) - yield (tf.convert_to_tensor(np.load(self.pathA['validation'][i]) , dtype=tf.float32), i) - + i = random.randint(0, self.pathA['validation'].shape[0] - 1) + yield tf.convert_to_tensor(np.load(self.pathA['validation'][i]), dtype=tf.float32), i + def valDatagenB(self): while True: - i = random.randint(0,self.pathB['validation'].shape[0]-1) - yield (tf.convert_to_tensor(np.load(self.pathB['validation'][i]), dtype=tf.float32), i) - - + i = random.randint(0, self.pathB['validation'].shape[0] - 1) + yield tf.convert_to_tensor(np.load(self.pathB['validation'][i]), dtype=tf.float32), i + ''' Functions for data preprocessing ''' + def body(self, arr, image): return [tf.image.random_crop(image, size=self.segmentation_patch_shape), image] - + def imagingCondition(self, arr, image): return tf.math.less(tf.math.reduce_max(arr), self.IMAGE_THRESH) - + def segmentationCondition(self, arr, image): return tf.math.less(tf.math.reduce_max(arr), self.SEG_THRESH) - + def processImagingDomain(self, image): ''' Standardizes image data and creates subvolumes ''' subvol = tf.image.random_crop(image, size=self.imaging_patch_shape) if self.otf_imaging is not None: subvol = self.otf_imaging(subvol) return subvol - + def processSegDomain(self, image): arr = tf.image.random_crop(value=image, size=self.segmentation_patch_shape) arr, _ = tf.while_loop(self.segmentationCondition, self.body, [arr, image], maximum_iterations=10) return arr - + def plotSampleDataset(self): """ Plots a sample of the input datasets A and B along with their histograms. The function saves a 3D TIFF file of the input data. Args: - - self.trainDatasetA: Dataset A. - - self.trainDatasetB: Dataset B. + - self.imaging_train_dataset: Dataset A. + - self.segmentation_train_dataset: Dataset B. - self.args.DIMENSIONS: Dimensionality of the input data. - self.args.SUBVOL_PATCH_SIZE: Size of the subvolume patch. 
Returns: - None """ - + # Visualise some examples if self.args.DIMENSIONS == 2: nfig = 1 else: nfig = 6 - - fig, axs = plt.subplots(nfig+1, 2, figsize=(10, 15)) + + fig, axs = plt.subplots(nfig + 1, 2, figsize=(10, 15)) fig.subplots_adjust(hspace=0.5) - for i, samples in enumerate(zip(self.trainDatasetA.take(1), self.trainDatasetB.take(1))): - + for i, samples in enumerate(zip(self.imaging_train_dataset.take(1), self.segmentation_train_dataset.take(1))): + dA = samples[0].numpy() dB = samples[1].numpy() - + if self.args.DIMENSIONS == 3: ''' Save 3D images ''' - io.imsave("./GANMonitor/Test_Input_A.tiff", - np.transpose(dA,(2,0,1,3)), - bigtiff=False, check_contrast=False) - - io.imsave("./GANMonitor/Test_Input_B.tiff", - np.transpose(dB,(2,0,1,3)), - bigtiff=False, check_contrast=False) - + io.imsave("./GANMonitor/Test_Input_A.tiff", + np.transpose(dA, (2, 0, 1, 3)), + bigtiff=False, check_contrast=False) + + io.imsave("./GANMonitor/Test_Input_B.tiff", + np.transpose(dB, (2, 0, 1, 3)), + bigtiff=False, check_contrast=False) + if self.args.DIMENSIONS == 2: showA = (dA * 127.5 + 127.5).astype('uint8') showB = (dB * 127.5 + 127.5).astype('uint8') axs[0, 0].imshow(showA, cmap='gray') axs[0, 1].imshow(showB, cmap='gray') else: - for j in range(0,nfig): - showA = (dA[:,:,j*int(self.args.SUBVOL_PATCH_SIZE[2]/nfig),]) - showB = (dB[:,:,j*int(self.args.SUBVOL_PATCH_SIZE[2]/nfig),]) + for j in range(0, nfig): + showA = (dA[:, :, j * int(self.args.SUBVOL_PATCH_SIZE[2] / nfig), ]) + showB = (dB[:, :, j * int(self.args.SUBVOL_PATCH_SIZE[2] / nfig), ]) axs[j, 0].imshow(showA, cmap='gray') axs[j, 1].imshow(showB, cmap='gray') - + ''' Include histograms ''' - axs[nfig,0].hist(dA.ravel(), bins=256, range=(np.amin(dA),np.amax(dA)), fc='k', ec='k', density=True) - axs[nfig,1].hist(dB.ravel(), bins=256, range=(np.amin(dB),np.amax(dB)), fc='k', ec='k', density=True) - + axs[nfig, 0].hist(dA.ravel(), bins=256, range=(np.amin(dA), np.amax(dA)), fc='k', ec='k', density=True) + axs[nfig, 1].hist(dB.ravel(), bins=256, range=(np.amin(dB), np.amax(dB)), fc='k', ec='k', density=True) + # Set axis labels axs[0, 0].set_title('Dataset A Example (XY Slices)') axs[0, 1].set_title('Dataset B Example (XY Slices)') axs[nfig, 0].set_ylabel('Voxel Frequency') plt.show(block=False) plt.close() - - + if self.args.DIMENSIONS == 3: _, axs = plt.subplots(nfig, 2, figsize=(10, 15)) - for j in range(0,nfig): - showA = dA[:,j*int(self.args.SUBVOL_PATCH_SIZE[1]/nfig),:,0] - showB = dB[:,j*int(self.args.SUBVOL_PATCH_SIZE[1]/nfig),:self.args.SUBVOL_PATCH_SIZE[2]-1,0] + for j in range(0, nfig): + showA = dA[:, j * int(self.args.SUBVOL_PATCH_SIZE[1] / nfig), :, 0] + showB = dB[:, j * int(self.args.SUBVOL_PATCH_SIZE[1] / nfig), :self.args.SUBVOL_PATCH_SIZE[2] - 1, + 0] axs[j, 0].imshow(showA, cmap='gray') axs[j, 1].imshow(showB, cmap='gray') - + # Set axis labels axs[0, 0].set_title('Dataset A Example (YZ Slices)') axs[0, 1].set_title('Dataset B Example (YZ Slices)') plt.show(block=False) plt.close() - diff --git a/discriminator.py b/discriminator.py index bd332c4..df14c1f 100644 --- a/discriminator.py +++ b/discriminator.py @@ -1,22 +1,22 @@ +import tensorflow_addons as tfa from tensorflow import keras from tensorflow.keras import layers from building_blocks import downsample, ReflectionPadding3D -import tensorflow_addons as tfa -import tensorflow as tf + def get_discriminator( - input_img_size=(64, 64, 512, 1), - batch_size=None, - filters=64, - kernel_initializer=None, - num_downsampling=3, - use_dropout=False, - wasserstein=False, - 
use_SN=False,
-    use_input_noise=False,
-    use_layer_noise=False,
-    name=None,
-    noise_std=0.1
+        input_img_size=(64, 64, 512, 1),
+        batch_size=None,
+        filters=64,
+        kernel_initializer='he_normal',
+        num_downsampling=3,
+        use_dropout=False,
+        wasserstein=False,
+        use_SN=False,
+        use_input_noise=False,
+        use_layer_noise=False,
+        name=None,
+        noise_std=0.1
 ):
     """
     Creates a discriminator model for a 3D volumetric image using convolutional layers.
 
@@ -29,11 +29,11 @@ def get_discriminator(
     - kernel_initializer: The initializer for the convolutional kernels. Default is None.
     - num_downsampling: Int, the number of times to downsample the input image with convolutional layers. Default is 3.
-    - use_dropout: Bool, whether or not to use dropout in the model. Default is False.
-    - wasserstein: Bool, whether or not the model is a Wasserstein GAN. Default is False.
-    - use_SN: Bool, whether or not to use spectral normalization in the convolutional layers. Default is False.
-    - use_input_noise: Bool, whether or not to add Gaussian noise to the input image. Default is False.
-    - use_layer_noise: Bool, whether or not to add Gaussian noise to the convolutional layers. Default is False.
+    - use_dropout: Bool, whether to use dropout in the model. Default is False.
+    - wasserstein: Bool, whether the model is a Wasserstein GAN. Default is False.
+    - use_SN: Bool, whether to use spectral normalization in the convolutional layers. Default is False.
+    - use_input_noise: Bool, whether to add Gaussian noise to the input image. Default is False.
+    - use_layer_noise: Bool, whether to add Gaussian noise to the convolutional layers. Default is False.
     - name: String, name for the model. Default is None.
     - noise_std: Float, the standard deviation of the Gaussian noise to add to the input and/or convolutional layers. Default is 0.1.
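An illustrative construction call (hypothetical sizes; the keyword names follow the signature above):

disc = get_discriminator(
    input_img_size=(64, 64, 128, 1),
    filters=64,
    num_downsampling=3,
    use_SN=True,           # wraps each Conv3D in SpectralNormalization
    use_layer_noise=True,  # Gaussian noise after each reflection pad
    noise_std=0.1,
    name="disc_I",
)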
@@ -48,7 +48,6 @@ def get_discriminator(
     x = ReflectionPadding3D()(img_input)
     if use_input_noise:
         x = layers.GaussianNoise(noise_std)(x)
-    
 
     if use_SN:
         x = tfa.layers.SpectralNormalization(layers.Conv3D(
@@ -67,7 +66,7 @@ def get_discriminator(
             kernel_initializer=kernel_initializer,
         )(x)
         x = tfa.layers.InstanceNormalization(gamma_initializer=None)(x)
-    
+
     x = layers.LeakyReLU(0.2)(x)
 
     num_filters = filters
@@ -81,7 +80,7 @@ def get_discriminator(
                 kernel_size=(4, 4, 4),
                 strides=(2, 2, 2),
                 use_dropout=use_dropout,
-                use_SN=use_SN,
+                use_spec_norm=use_SN,
                 use_layer_noise=use_layer_noise,
                 noise_std=noise_std
             )
@@ -94,14 +93,14 @@ def get_discriminator(
                 strides=(1, 1, 1),
                 use_dropout=use_dropout,
                 padding='same',
-                use_SN=use_SN,
+                use_spec_norm=use_SN,
                 use_layer_noise=use_layer_noise,
                 noise_std=noise_std
             )
 
     if use_layer_noise:
         x = layers.GaussianNoise(noise_std)(x)
-    
+
     x = layers.Conv3D(
         1,
         (3, 3, 3),
@@ -109,7 +108,7 @@ def get_discriminator(
         padding="same",
         kernel_initializer=kernel_initializer,
     )(x)
-    
+
     if wasserstein:
         x = layers.Flatten()(x)
         x = layers.Dropout(0.2)(x)
@@ -119,4 +118,3 @@ def get_discriminator(
     model.summary()
 
     return model
-
diff --git a/generator.py b/generator.py
index 4fac407..9590f06 100644
--- a/generator.py
+++ b/generator.py
@@ -1,38 +1,35 @@
+import tensorflow_addons as tfa
 from tensorflow import keras
 from tensorflow.keras import layers
+from building_blocks import downsample, upsample, residual_block, ReflectionPadding3D
 
-import tensorflow_addons as tfa
-
-from building_blocks import downsample, deconv, upsample, residual_block, ReflectionPadding3D
 
 def get_resnet_generator(
-    input_img_size=(64,64,512,1),
-    batch_size=None,
-    filters=32,
-    num_downsampling_blocks=2,
-    num_residual_blocks=6,
-    num_upsample_blocks=2,
-    gamma_initializer='he_normal',
-    kernel_initializer='he_normal',
-    name=None,
+        input_img_size=(64, 64, 512, 1),
+        batch_size=None,
+        filters=32,
+        num_downsampling_blocks=2,
+        num_residual_blocks=6,
+        num_upsample_blocks=2,
+        gamma_initializer='he_normal',
+        kernel_initializer='he_normal',
+        name=None,
 ):
     """
     Returns a 3D ResNet generator model.
 
-    Args:
-        input_img_size (tuple): The size of the input image (height, width, depth, channels).
-        batch_size (int, optional): The batch size to be used for the model. Defaults to None.
-        filters (int, optional): The number of filters in the first convolutional layer. Defaults to 32.
-        num_downsampling_blocks (int, optional): The number of downsampling blocks in the generator. Defaults to 2.
-        num_residual_blocks (int, optional): The number of residual blocks in the generator. Defaults to 6.
-        num_upsample_blocks (int, optional): The number of upsampling blocks in the generator. Defaults to 2.
-        gamma_initializer (str, optional): The initializer to be used for the instance normalization gamma. Defaults to 'he_normal'.
-        kernel_initializer (str, optional): The initializer to be used for the convolutional kernels. Defaults to 'he_normal'.
-        name (str, optional): The name of the model. Defaults to None.
+    Args:
+        input_img_size (tuple): The size of the input image (height, width, depth, channels).
+        batch_size (int, optional): The batch size to be used for the model. Defaults to None.
+        filters (int, optional): The number of filters in the first convolutional layer. Defaults to 32.
+        num_downsampling_blocks (int, optional): The number of downsampling blocks in the generator. Defaults to 2.
+        num_residual_blocks (int, optional): The number of residual blocks in the generator. Defaults to 6.
+        num_upsample_blocks (int, optional): The number of upsampling blocks in the generator. Defaults to 2.
+        gamma_initializer (str, optional): The initializer to be used for the instance normalization gamma.
+            Defaults to 'he_normal'.
+        kernel_initializer (str, optional): The initializer to be used for the convolutional kernels.
+            Defaults to 'he_normal'.
+        name (str, optional): The name of the model. Defaults to None.
 
     Returns:
         tensorflow.keras.models.Model: The 3D ResNet generator model.
-    """
+    """
     img_input = layers.Input(shape=input_img_size, batch_size=batch_size, name=name + "_img_input")
 
     x = ReflectionPadding3D(padding=(1, 1, 1))(img_input)
@@ -74,4 +71,3 @@ def get_resnet_generator(
     model.summary()
 
     return model
-
diff --git a/loss_functions.py b/loss_functions.py
index 606b69e..03facbe 100644
--- a/loss_functions.py
+++ b/loss_functions.py
@@ -37,29 +37,29 @@ def MSLE(self, real, fake):
 def MAE(self, y_true, y_pred):
     """
     Compute the per-sample mean absolute error (MAE) between the true and predicted tensors.
-    
+
     Args:
     - y_true: A tensor of true values.
     - y_pred: A tensor of predicted values.
-    
+
     Returns:
     - A scalar tensor representing the per-sample MAE between the true and predicted tensors.
     """
-    return reduce_mean(self, tf.abs(y_true - y_pred), axis=list(range(1, len(y_true.shape)))) 
+    return reduce_mean(self, tf.abs(y_true - y_pred), axis=list(range(1, len(y_true.shape))))
 
 @tf.function
 def MSE(self, y_true, y_pred):
     """
     Compute the per-sample mean squared error (MSE) between the true and predicted tensors.
-    
+
     Args:
     - y_true: A tensor of true values.
     - y_pred: A tensor of predicted values.
-    
+
     Returns:
     - A scalar tensor representing the per-sample MSE between the true and predicted tensors.
     """
-    return reduce_mean(self, tf.square(y_true - y_pred), axis=list(range(1, len(y_true.shape)))) 
+    return reduce_mean(self, tf.square(y_true - y_pred), axis=list(range(1, len(y_true.shape))))
 
 @tf.function
 def L4(self, y_true, y_pred):
@@ -143,16 +143,16 @@ def cycle_loss(self, real_image, cycled_image, typ=None):
     return reduce_mean(self, loss_obj(real, cycled)) * self.lambda_cycle
 
 @tf.function
-def cycle_perceptual(self, real_image, cycled_image):
+def cycle_reconstruction(self, real_image, cycled_image):
     """
-    Return the per sample cycle perceptual loss using Structural Similarity Index (SSIM) loss
+    Return the per sample cycle reconstruction loss using Structural Similarity Index (SSIM) loss
 
     Args:
     - real_image: Tensor, shape (batch_size, H, W, C), representing the real image
     - cycled_image: Tensor, shape (batch_size, H, W, C), representing the cycled image
 
     Returns:
-    - loss: float Tensor, representing the per sample cycle perceptual loss
+    - loss: float Tensor, representing the per sample cycle reconstruction loss
     """
     real = min_max_norm_tf(real_image)
     cycled = min_max_norm_tf(cycled_image)
diff --git a/main.py b/main.py
index f11d5cb..b19aadd 100644
--- a/main.py
+++ b/main.py
@@ -1,24 +1,23 @@
 import os
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
-os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
-
 import shutil
 import glob
 import argparse
 import numpy as np
 import scipy.stats as sp
 import tensorflow as tf
-tf.keras.backend.clear_session()
-
 from time import time
 from vangan import VanGan, train
-from custom_callback import GAN_Monitor
-from dataset import dataset_gen
-from preprocessing import DataPrepocessor
+from custom_callback import GanMonitor
+from dataset import DatasetGen
+from preprocessing import DataPreprocessor
 from tb_callback import TB_Summary
 from utils import min_max_norm_tf,
rescale_arr_tf, z_score_norm from post_training import epoch_sweep +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private' + +tf.keras.backend.clear_session() print('*** Setting up GPU ***') ''' SET GPU MEMORY USAGE ''' @@ -27,7 +26,6 @@ for i in range(len(physical_devices)): tf.config.experimental.set_memory_growth(physical_devices[i], True) - ''' SET TF GPU STRATEGY ''' strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1', 'GPU:2', 'GPU:3']) # strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() @@ -43,7 +41,7 @@ ''' ORGANISE TENSORBOARD OUTPUT FOLDERS ''' print('*** Organising tensorboard folders ***') -tensorboardDir = 'logs' +tensorboardDir = 'TB_Logs' monitorDir = 'GANMonitor' if os.path.isdir(tensorboardDir): shutil.rmtree(tensorboardDir) @@ -55,9 +53,8 @@ os.remove(f) else: os.makedirs(monitorDir) - -summary = TB_Summary('logs/') # Initialise TensorBoard summary helper +summary = TB_Summary(tensorboardDir) # Initialise TensorBoard summary helper ''' SET PARAMETERS ''' print('*** Setting VANGAN parameters ***') @@ -70,23 +67,21 @@ # Training parameters args.EPOCHS = 200 -args.BATCH_SIZE = 4 +args.BATCH_SIZE = 3 args.GLOBAL_BATCH_SIZE = args.N_DEVICES * args.BATCH_SIZE args.PREFETCH_SIZE = 4 -args.INITIAL_LR = 2e-4 # Learning rate -args.INITIATE_LR_DECAY = 0.5 * args.EPOCHS # Set start of learning rate decay to 0 -args.NO_NOISE = args.EPOCHS # Set when discriminator noise decays to 0 -args.KERNEL_INIT = tf.keras.initializers.HeNormal() # Weights initializer for the layers. -args.GAMMA_INIT = tf.keras.initializers.HeNormal() # Gamma initializer for instance normalization. +args.INITIAL_LR = 2e-4 # Learning rate +args.INITIATE_LR_DECAY = 0.5 * args.EPOCHS # Set start of learning rate decay to 0 +args.NO_NOISE = args.EPOCHS # Set when discriminator noise decays to 0 # Image parameters args.CHANNELS = 1 args.DIMENSIONS = 3 -args.RAW_IMG_SIZE = (512, 512, 140, args.CHANNELS) # Unprocessed imaging domain image dimensions -args.TARG_RAW_IMG_SIZE = (512, 512, 128, args.CHANNELS) # Target size if downsampling -args.SYNTH_IMG_SIZE = (512, 512, 128) # Unprocessed segmentation domain image dimensions -args.TARG_SYNTH_IMG_SIZE = (512, 512, 128) # Target size if downsampling -args.SUBVOL_PATCH_SIZE = (128, 128, 128) # Size of subvolume to be trained on +args.RAW_IMG_SIZE = (512, 512, 140, args.CHANNELS) # Unprocessed imaging domain image dimensions +args.TARG_RAW_IMG_SIZE = (512, 512, 128, args.CHANNELS) # Target size if downsampling +args.SYNTH_IMG_SIZE = (512, 512, 128) # Unprocessed segmentation domain image dimensions +args.TARG_SYNTH_IMG_SIZE = (512, 512, 128) # Target size if downsampling +args.SUBVOL_PATCH_SIZE = (128, 128, 128) # Size of subvolume to be trained on # Set model input image size for training (based on above) if args.DIMENSIONS == 2: args.INPUT_IMG_SIZE = ( @@ -105,12 +100,11 @@ ) # Set callback parameters -args.PERIOD_2D_CALLBACK = 2 # Period of epochs to output a 2D validation dataset example -args.PERIOD_3D_CALLBACK = 2 # Period of epochs to output a 3D validation dataset example - +args.PERIOD_2D_CALLBACK = 2 # Period of epochs to output a 2D validation dataset example +args.PERIOD_3D_CALLBACK = 2 # Period of epochs to output a 3D validation dataset example -'''' PREPROCESSING ''' -imaging_data = DataPrepocessor(args, +'''' PREPROCESSING ''' +imaging_data = DataPreprocessor(args, raw_path='/mnt/sdb/3DcycleGAN_simLNet_LNet/raw_data/simLNet', main_dir='/mnt/sdb/3DcycleGAN_simLNet_LNet/', partition_id='A', @@ -118,13 
+112,14 @@ tiff_size=args.RAW_IMG_SIZE, target_size=args.TARG_RAW_IMG_SIZE) -synth_data = DataPrepocessor(args, - raw_path='/mnt/sdb/3DcycleGAN_simLNet_LNet/raw_data/LNet', - main_dir='/mnt/sdb/3DcycleGAN_simLNet_LNet/', - partition_id='B', - partition_filename='dataB_partition.pkl', - tiff_size=args.SYNTH_IMG_SIZE, - target_size=args.TARG_SYNTH_IMG_SIZE) +synth_data = DataPreprocessor(args, + raw_path='/mnt/sdb/3DcycleGAN_simLNet_LNet/raw_data/LNet', + main_dir='/mnt/sdb/3DcycleGAN_simLNet_LNet/', + partition_id='B', + partition_filename='dataB_partition.pkl', + tiff_size=args.SYNTH_IMG_SIZE, + target_size=args.TARG_SYNTH_IMG_SIZE) + # Function used for preprocessing imaging domain images # The following is used for preprocessing raster-scanning optoacoustic mesoscopic (RSOM) image volumes @@ -140,19 +135,20 @@ def preprocess_rsom_images(img, lower_thresh=0.05, upper_thresh=99.95): Returns: - np.ndarray: The preprocessed 3D numpy array. """ - + # Slice-wise Z-Score Normalisation for z in range(img.shape[2]): - img[...,z] = z_score_norm(img[...,z]) - + img[..., z] = z_score_norm(img[..., z]) + # Clipping of upper and lower percentiles lp = sp.scoreatpercentile(img, lower_thresh) up = sp.scoreatpercentile(img, upper_thresh) img[img < lp] = lp img[img > up] = up - + return img + # Perform any preprocessing of images if neccessary # imaging_data.preprocess(preprocess_fn=preprocess_rsom_images, # save_filtered=True, @@ -163,98 +159,95 @@ def preprocess_rsom_images(img, lower_thresh=0.05, upper_thresh=99.95): imaging_data.load_partition('/mnt/sdb/3DcycleGAN_simLNet_LNet/dataA_partition.pkl') synth_data.load_partition('/mnt/sdb/3DcycleGAN_simLNet_LNet/dataB_partition.pkl') - ''' GENERATE TENSORFLOW DATASETS ''' print('*** Generating datasets for model ***') -# Define function to preprocess imaging domain image on the fly (OTF) + + +# Define function to preprocess imaging domain image on the fly (otf) # Min/max normalisation and rescaling to [-1,1] shown here @tf.function -def process_imaging_OTF(image): +def process_imaging_otf(image): return rescale_arr_tf( - min_max_norm_tf(image) - ) + min_max_norm_tf(image) + ) -# Define dataset class -getDataset = dataset_gen(args = args, - imaging_domain_data = imaging_data.partition, - seg_domain_data = synth_data.partition, - strategy = strategy, - otf_imaging = process_imaging_OTF # Set to None if OTF processing not needed - ) +# Define dataset class +getDataset = DatasetGen(args=args, + imaging_domain_data=imaging_data.partition, + seg_domain_data=synth_data.partition, + strategy=strategy, + otf_imaging=process_imaging_otf # Set to None if OTF processing not needed + ) ''' CALCULATE NUMBER OF TRAINING / VALIDATION STEPS ''' -args.train_steps = int(np.amax([len(imaging_data.partition['training']), +args.train_steps = int(np.amax([len(imaging_data.partition['training']), len(synth_data.partition['training'])]) / args.GLOBAL_BATCH_SIZE) args.val_steps = int(np.amax([len(imaging_data.partition['validation']), len(synth_data.partition['validation'])]) / args.GLOBAL_BATCH_SIZE) - ''' DEFINE VANGAN ''' -vangan_model = VanGan(args, - strategy = strategy, - genAB_typ = 'resUnet', - genBA_typ = 'resUnet' - ) - +vangan_model = VanGan(args, + strategy=strategy, + gen_i2s='resUnet', + gen_s2i='resUnet' + ) ''' DEFINE CUSTOM CALLBACK ''' -plotter = GAN_Monitor(args, - dataset = getDataset, - Alist = imaging_data.partition['validation'], - Blist = synth_data.partition['validation'], - process_imaging_domain = process_imaging_OTF - ) - +plotter = GanMonitor(args, + 
dataset=getDataset,
+                     Alist=imaging_data.partition['validation'],
+                     Blist=synth_data.partition['validation'],
+                     process_imaging_domain=process_imaging_otf
+                     )
 
 ''' TRAIN VAN-GAN MODEL '''
 for epoch in range(args.EPOCHS):
     print(f'\nEpoch {epoch + 1:03d}/{args.EPOCHS:03d}')
     start = time()
-    
+
     vangan_model.current_epoch = epoch
     plotter.on_epoch_start(vangan_model, epoch, args)
-    
-    'Training GAN for fixed no. of steps'
-    results = train(args, getDataset.train_ds, vangan_model, summary, epoch, args.train_steps, 'Train')
+
+    # Train GAN for a fixed number of steps
+    results = train(getDataset.train_dataset, vangan_model, summary, epoch, args.train_steps, 'Train')
     summary.losses(results)
-    
-    'Run GAN for validation dataset'
-    results = train(args, getDataset.val_ds, vangan_model, summary, epoch, args.val_steps, 'Validate', training=False)
+
+    # Run GAN on the validation dataset
+    results = train(getDataset.val_dataset, vangan_model, summary, epoch, args.val_steps, 'Validate',
+                    training=False)
     summary.losses(results)
-    
-    
-    if (epoch) % args.PERIOD_2D_CALLBACK == 1 or epoch == args.EPOCHS - 1:
-        plotter.on_epoch_end(vangan_model, epoch, args) 
+
+    if epoch % args.PERIOD_2D_CALLBACK == 1 or epoch == args.EPOCHS - 1:
+        plotter.on_epoch_end(vangan_model, epoch, args)
         vangan_model.save_checkpoint(epoch=epoch)
-    
+
     end = time()
     summary.scalar('elapse', end - start, epoch=epoch, training=True)
 
-
 ''' CREATE VANGAN PREDICTIONS '''
 # Predict segmentation probability maps for imaging test dataset
-plotter.run_mapping(vangan_model, imaging_data.partition['testing'], args.INPUT_IMG_SIZE, filetext='VANGAN_', filepath=args.output_dir, segmentation=True, stride=(25,25,25))
+plotter.run_mapping(vangan_model, imaging_data.partition['testing'], args.INPUT_IMG_SIZE, filetext='VANGAN_',
+                    filepath=args.output_dir, segmentation=True, stride=(25, 25, 25))
 
-# Prediction fake imaging data using synthetic segmentation test dataset
-plotter.run_mapping(vangan_model, synth_data.partition['testing'], args.INPUT_IMG_SIZE, filetext='VANGAN_', filepath=args.output_dir, segmentation=False, stride=(25,25,25))
-
+# Predict fake imaging data using the synthetic segmentation test dataset
+plotter.run_mapping(vangan_model, synth_data.partition['testing'], args.INPUT_IMG_SIZE, filetext='VANGAN_',
+                    filepath=args.output_dir, segmentation=False, stride=(25, 25, 25))
 
 ''' TESTING PREDICTIONS ACROSS EPOCHS '''
-epoch_sweep(args,
-            vangan_model,
-            plotter,
-            test_path='/PATH/TO/TEST/DATA/', # Can use imaging_data.partition['testing']
-            start=100,
-            end=200,
-            segmentation=True # Set to False if fake imaging is wanted
+epoch_sweep(args,
+            vangan_model,
+            plotter,
+            test_path='/PATH/TO/TEST/DATA/',  # Can use imaging_data.partition['testing']
+            start=100,
+            end=200,
+            segmentation=True  # Set to False if fake imaging is wanted
             )
 
-
 ''' SEGMENTING NEW IMAGES '''
-# Alternatively, to run VANGAN on a directory of images (saved as .npy) using the following example script
-new_imaging_data = DataPrepocessor() # Create data preprocessor
-new_imaging_data.process_new_data(current_path='/PATH/TO/DATA/',
+# Alternatively, run VANGAN on a directory of images (saved as .npy) using the following example script
+new_imaging_data = DataPreprocessor()  # Create data preprocessor
+new_imaging_data.process_new_data(current_path='/PATH/TO/DATA/',
                                   new_path='/PATH/TO/SAVE/DATA/',
                                   preprocess_fn=preprocess_rsom_images,
                                   tiff_size=args.RAW_IMG_SIZE,
@@ -265,4 +258,6 @@ def process_imaging_OTF(image):
 img_files = os.listdir(filepath)
-for file in img_files:
-    img_files[file] = os.path.join(filepath, file)
-plotter.run_mapping(vangan_model, img_files, args.INPUT_IMG_SIZE, filetext='VANGAN_', filepath=args.output_dir, segmentation=True, stride=(25,25,25))
+# Build full paths with a list comprehension (the previous dict-style assignment
+# indexed the list with a filename string, which raises a TypeError)
+img_files = [os.path.join(filepath, file) for file in img_files]
+plotter.run_mapping(vangan_model, img_files, args.INPUT_IMG_SIZE, filetext='VANGAN_', filepath=args.output_dir,
+                    segmentation=True, stride=(25, 25, 25))
diff --git a/post_training.py b/post_training.py
index 0ceeae3..f9302f9 100644
--- a/post_training.py
+++ b/post_training.py
@@ -1,5 +1,6 @@
 import os
 
+
 def epoch_sweep(args, vangan_model, plotter, test_path='', start=100, end=200, step=2, segmentation=True):
     """
     Perform a sweep of epochs for the given VANGAN model and save the resulting images using the given plotter.
@@ -18,22 +19,22 @@ def epoch_sweep(args, vangan_model, plotter, test_path='', start=100, end=200, s
     - None
     """
 
-    test_path = '/mnt/sda/VS-GAN_deepVess/testA/'
-
-    for i in range(start,end+1,step):
+    for i in range(start, end + 1, step):
         print(i)
         vangan_model.load_checkpoint(epoch=i,
-                                     newpath='/mnt/sdb/TPLSM/Boas_DeepVess_Image_Standardisation/VG_Output/checkpoints')
-
+                                     newpath=os.path.join(args.output_dir, 'checkpoints'))
+
         # Make epoch folders
-        filepath = '/mnt/sdb/TPLSM/Boas_DeepVess_Image_Standardisation/Epoch_Sampling/'
+        filepath = os.path.join(args.output_dir, 'Epoch_Sampling')
         folder = os.path.join(filepath, 'e{idx}'.format(idx=i))
         if not os.path.isdir(folder):
             os.makedirs(folder)
-
+
         testfiles = os.listdir(test_path)
         filename = 'e{idx}_VG_'.format(idx=i)
-        for file in testfiles:
-            testfiles[file] = os.path.join(test_path,file)
-
-        plotter.run_mapping(vangan_model, testfiles, args.INPUT_IMG_SIZE, filetext=filename, segmentation=segmentation, stride=(50,50,50), filepath=folder, padFactor=0.1)
+        # Build full paths with a list comprehension (the previous dict-style
+        # assignment indexed the list with a filename string, raising a TypeError)
+        testfiles = [os.path.join(test_path, file) for file in testfiles]
+
+        plotter.run_mapping(vangan_model, testfiles, args.INPUT_IMG_SIZE, filetext=filename, segmentation=segmentation,
+                            stride=(50, 50, 50), filepath=folder, padFactor=0.1)
diff --git a/preprocessing.py b/preprocessing.py
index c378c13..4d6e976 100644
--- a/preprocessing.py
+++ b/preprocessing.py
@@ -8,12 +8,17 @@
 from joblib import Parallel, delayed
 import multiprocessing
 
-from utils import min_max_norm, check_nan, save_dict, load_dict, get_vaccuum, resize_volume
+from utils import min_max_norm, check_nan, save_dict, load_dict, resize_volume
 
-class DataPrepocessor:
-    def __init__(self, args=None, raw_path=None, main_dir=None, partition_id='', partition_filename=None, tiff_size=(600, 600, 700),
+
+class DataPreprocessor:
+    def __init__(self, args=None, raw_path=None, main_dir=None, partition_id='', partition_filename=None,
+                 tiff_size=(600, 600, 700),
                  target_size=(600, 600, 700), num_cores=multiprocessing.cpu_count() - 1):
+        self.save_filtered = None
+        self.resize = None
+        self.preprocess_fn = None
         self.raw_path = raw_path
         self.main_dir = main_dir
         self.partition_id = partition_id
@@ -24,12 +29,12 @@ def __init__(self, args=None, raw_path=None, main_dir=None, partition_id='', par
         self.validate_files = None
         self.test_files = None
         self.partition = {}
-        
+
         self.NUM_CORES = int(0.8 * num_cores)
         if args is not None:
             self.DIMENSIONS = args.DIMENSIONS
             self.CHANNELS = args.CHANNELS
-    
+
     def save_partition(self, save_path=None):
         """
         Save the partition data into files in the specified directory.
@@ -40,151 +45,140 @@ def save_partition(self, save_path=None): Returns: None """ - + if save_path is None: raise ValueError("Partition save_path is not provided.") - + # Update partition directories new_partition = {} train_arr = np.empty(len(self.partition['training']), dtype=object) val_arr = np.empty(len(self.partition['validation']), dtype=object) test_arr = np.empty(len(self.partition['testing']), dtype=object) - + # Update the training partition directory for i in range(len(self.partition['training'])): file = self.partition['training'][i] file, _ = os.path.splitext(file) file = file + '.npy' - file = os.path.join(save_path, 'train'+self.partition_id, file) + file = os.path.join(save_path, 'train' + self.partition_id, file) train_arr[i] = file - + # Update the validation partition directory for i in range(len(self.partition['validation'])): file = self.partition['validation'][i] file, _ = os.path.splitext(file) file = file + '.npy' - file = os.path.join(save_path, 'val'+self.partition_id, file) + file = os.path.join(save_path, 'val' + self.partition_id, file) val_arr[i] = file - + # Update the testing partition directory for i in range(len(self.partition['testing'])): file = self.partition['testing'][i] file, _ = os.path.splitext(file) file = file + '.npy' - file = os.path.join(save_path, 'test'+self.partition_id, file) + file = os.path.join(save_path, 'test' + self.partition_id, file) test_arr[i] = file - + new_partition['training'] = train_arr new_partition['validation'] = val_arr new_partition['testing'] = test_arr - + save_dict(new_partition, os.path.join(save_path, self.partition_filename)) - + self.partition = new_partition - + def load_partition(self, file_path): - print('*** Loading Dataset %s Partition ***' %(self.partition_id)) + print('*** Loading Dataset %s Partition ***' % self.partition_id) self.partition = load_dict(file_path) - + def split_dataset(self): - + # Shuffle raw data list files = os.listdir(self.raw_path) random.shuffle(files) - + # Split data into train/validate/test print('Splitting dataset ...') - self.train_files, self.test_files = np.split(files, [int(len(files)*0.9)]) - self.train_files, self.validate_files = np.split(files, [int(len(files)*0.8)]) - + self.train_files, self.test_files = np.split(files, [int(len(files) * 0.9)]) + self.train_files, self.validate_files = np.split(files, [int(len(files) * 0.8)]) + # Save partitioned dataset self.partition['training'] = self.train_files self.partition['validation'] = self.validate_files self.partition['testing'] = self.test_files - - + def move_dataset(self): for file in range(len(self.partition['training'])): shutil.move(os.path.join(self.raw_path, self.partition['training'][file]), - os.path.join(self.main_dir, 'train'+self.partition_id)) + os.path.join(self.main_dir, 'train' + self.partition_id)) for file in range(len(self.partition['validation'])): shutil.move(os.path.join(self.raw_path, self.partition['validation'][file]), - os.path.join(self.main_dir, 'val'+self.partition_id)) + os.path.join(self.main_dir, 'val' + self.partition_id)) for file in range(len(self.partition['testing'])): shutil.move(os.path.join(self.raw_path, self.partition['testing'][file]), - os.path.join(self.main_dir, 'test'+self.partition_id)) - + os.path.join(self.main_dir, 'test' + self.partition_id)) + def preprocess(self, preprocess_fn=None, resize=False, save_filtered=False): - - print('*** Preprocessing partition %s images ***' %(self.partition_id)) + + print('*** Preprocessing partition %s images ***' % self.partition_id) 
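+        # preprocess() first splits the raw file list into train/val/test below and
+        # then converts each partition in parallel with joblib (self.NUM_CORES jobs).
+        # A minimal usage sketch, mirroring the commented-out example in main.py
+        # (the paths here are placeholders, not real defaults):
+        #   data = DataPreprocessor(args, raw_path='/PATH/TO/RAW', main_dir='/PATH/TO/OUT',
+        #                           partition_id='A', partition_filename='dataA_partition.pkl')
+        #   data.preprocess(preprocess_fn=None, resize=False, save_filtered=False)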
self.split_dataset() - + self.preprocess_fn = preprocess_fn self.resize = resize self.save_filtered = save_filtered - + print('Processing training data ...') Parallel(n_jobs=self.NUM_CORES, verbose=50)(delayed( - self.process_tiff)(file=self.partition['training'][file], + self.process_tiff)(file=self.partition['training'][file], label='train') for file in range(len(self.partition['training']))) - + print('Processing validation data ...') Parallel(n_jobs=self.NUM_CORES, verbose=50)(delayed( - self.process_tiff)(file=self.partition['validation'][file], + self.process_tiff)(file=self.partition['validation'][file], label='val') for file in range(len(self.partition['validation']))) - + print('Processing testing data ...') Parallel(n_jobs=self.NUM_CORES, verbose=50)(delayed( - self.process_tiff)(file=self.partition['testing'][file], + self.process_tiff)(file=self.partition['testing'][file], label='test') for file in range(len(self.partition['testing']))) - + self.save_partition(self.main_dir) - + def process_tiff(self, file, label=''): - + """ Process a TIFF image file. Args: - args (argparse.Namespace): The command line arguments file (str): The name of the file to be processed - raw_path (str): The path to the directory containing the raw images - main_dir (str): The path to the directory to save the processed images - tiff_size (int or tuple): The size of the TIFF image - target_size (int or tuple): The desired size of the processed image label (str): The label to be appended to the processed image - partition_id (str): The partition ID to be appended to the processed image - resize (bool): Whether to resize the image or not - preprocess_fn (function): The function to be used for preprocessing the image - save_filtered (bool): Whether to save the filtered image or not Returns: None """ - stack = (sk.imread(os.path.join(self.raw_path, file))).astype('float32') - + file, ext = os.path.splitext(file) # if partition_id == 'A': if self.DIMENSIONS == 3: stack = np.transpose(stack, (1, 2, 0)) - + # if self.partition_id == 'B': - # stack = get_vaccuum(stack, self.DIMENSIONS) # Reduce bounding box to tree size - + # stack = get_vacuum(stack, self.DIMENSIONS) # Reduce bounding box to tree size + if self.preprocess_fn is not None: stack = self.preprocess_fn(stack) - + if not self.tiff_size == self.target_size and self.resize: stack = (resize_volume(stack, self.target_size)).astype('float32') if self.partition_id == 'B': stack[stack < 0.] = 0.0 stack[stack > 255.] = 255 - + stack = min_max_norm(stack) if self.partition_id == 'B': - mode, _ = stats.mode(stack, axis=None) + mode, _ = stats.mode(stack, axis=None) if mode == 1: stack -= 1. stack = abs(stack) @@ -195,30 +189,34 @@ def process_tiff(self, file, label=''): stack[stack >= 0.] 
= 1.0 if not check_nan(stack): - + if self.save_filtered: - arr_out = os.path.join(os.path.join(self.main_dir,'filtered'), - label+self.partition_id, file+'.tiff') + arr_out = os.path.join(os.path.join(self.main_dir, 'filtered'), + label + self.partition_id, file + '.tiff') if ext == '.npy': - sk.imsave(arr_out, (stack* 127.5 + 127.5).astype('uint8'), bigtiff=False, check_contrast=False) + sk.imsave(arr_out, (stack * 127.5 + 127.5).astype('uint8'), bigtiff=False, check_contrast=False) else: if self.DIMENSIONS == 2: sk.imsave(arr_out, (stack * 127.5 + 127.5).astype('uint8'), bigtiff=False, check_contrast=False) else: - sk.imsave(arr_out, (np.transpose(stack,(2,1,0)) * 127.5 + 127.5).astype('uint8'), bigtiff=False, check_contrast=False) - + sk.imsave(arr_out, (np.transpose(stack, (2, 1, 0)) * 127.5 + 127.5).astype('uint8'), + bigtiff=False, check_contrast=False) + if self.partition_id == 'B': - np.save(os.path.join(self.main_dir, label+self.partition_id, file), np.expand_dims(stack, axis=self.DIMENSIONS)) + np.save(os.path.join(self.main_dir, label + self.partition_id, file), + np.expand_dims(stack, axis=self.DIMENSIONS)) else: if self.DIMENSIONS == 2 and self.CHANNELS == 3: - np.save(os.path.join(self.main_dir, label+self.partition_id, file), stack) + np.save(os.path.join(self.main_dir, label + self.partition_id, file), stack) else: - np.save(os.path.join(self.main_dir, label+self.partition_id, file), np.expand_dims(stack, axis=self.DIMENSIONS)) + np.save(os.path.join(self.main_dir, label + self.partition_id, file), + np.expand_dims(stack, axis=self.DIMENSIONS)) else: print('NaN detected ...') - - def process_new_data(self, current_path, new_path, tiff_size=None, target_size=None, preprocess_fn=None, resize=None): - + + def process_new_data(self, current_path, new_path, tiff_size=None, target_size=None, preprocess_fn=None, + resize=None): + self.raw_path = current_path self.main_dir = new_path self.tiff_size = tiff_size @@ -226,9 +224,7 @@ def process_new_data(self, current_path, new_path, tiff_size=None, target_size=N self.preprocess_fn = preprocess_fn self.resize = resize self.save_filtered = False - + files = os.listdir(current_path) for file in files: - self.process_tiff(file = file) - - + self.process_tiff(file=file) diff --git a/resunet_model.py b/resunet_model.py index 4322d13..92dfc09 100644 --- a/resunet_model.py +++ b/resunet_model.py @@ -14,37 +14,37 @@ Activation, Add, GaussianNoise - ) +) import tensorflow_addons as tfa from building_blocks import ReflectionPadding3D from vnet_model import attention_concat -def norm_act(x, + +def norm_act(x, act=True): """ - Apply instance normalization and activation function (ReLU by default) to input tensor. + Apply instance normalisation and activation function (ReLU by default) to input tensor. Args: x (tensor): Input tensor. act (bool): Whether to apply activation function. Default is True. Returns: - tensor: Output tensor after instance normalization and activation (if applicable). + tensor: Output tensor after instance normalisation and activation (if applicable). 
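+
+    Example (illustrative):
+        y = norm_act(x)             # instance normalisation followed by ReLU
+        y = norm_act(x, act=False)  # instance normalisation only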
""" x = tfa.layers.InstanceNormalization()(x) - if act == True: + if act: x = Activation("relu")(x) return x -def conv_block(x, - filters, - kernel_size=(3, 3, 3), - padding="valid", - strides=1, - kernel_initializer=None, - dropout_type=None, - dropout=None): + +def conv_block(x, + filters, + kernel_size=(3, 3, 3), + padding="valid", + strides=1, + kernel_initializer='he_normal'): """ A convolutional block that consists of a normalization and activation layer followed by a convolutional layer. @@ -55,8 +55,6 @@ def conv_block(x, padding (str, optional): The type of padding to apply. Defaults to "valid". strides (int, optional): The stride of the convolution. Defaults to 1. kernel_initializer (str, optional): The name of the kernel initializer to use. Defaults to None. - dropout_type (str, optional): The type of dropout to apply. Defaults to None. - dropout (float, optional): The dropout rate to apply. Defaults to None. Returns: tensor: The output tensor after passing through the convolutional block. @@ -67,10 +65,11 @@ def conv_block(x, conv = Conv3D(filters, kernel_size, padding=padding, strides=strides, kernel_initializer=kernel_initializer)(conv) return conv -def stem(x, - filters, - kernel_size=(3, 3, 3), - padding="valid", + +def stem(x, + filters, + kernel_size=(3, 3, 3), + padding="valid", strides=1): """ The stem operation for the start of the deep residual UNet. @@ -89,24 +88,25 @@ def stem(x, if padding == 'valid': conv = ReflectionPadding3D()(x) conv = Conv3D(filters, kernel_size, padding=padding, strides=strides)(conv) - else: + else: conv = Conv3D(filters, kernel_size, padding=padding, strides=strides)(x) conv = conv_block(conv, filters, kernel_size=kernel_size, padding=padding, strides=strides) - + # Identity mapping shortcut = Conv3D(filters, kernel_size=(1, 1, 1), padding="same", strides=strides)(x) shortcut = norm_act(shortcut, act=False) - + output = Add()([conv, shortcut]) return output -def residual_block(x, - filters, - kernel_size=(3, 3, 3), - padding="valid", - strides=1, - kernel_initializer=None, - dropout_type=None, + +def residual_block(x, + filters, + kernel_size=(3, 3, 3), + padding="valid", + strides=1, + kernel_initializer='he_normal', + dropout_type=None, dropout=None): """ Constructs a residual block of the 3D residual UNet architecture. @@ -124,83 +124,126 @@ def residual_block(x, Returns: tensor: The output tensor of the residual block. 
""" - res = conv_block(x, filters, kernel_size=kernel_size, padding=padding, strides=strides, kernel_initializer=kernel_initializer, dropout_type=dropout_type, dropout=dropout) - res = conv_block(res, filters, kernel_size=kernel_size, padding=padding, strides=1, kernel_initializer=kernel_initializer, dropout_type=dropout_type, dropout=dropout) - + res = conv_block(x, filters, kernel_size=kernel_size, padding=padding, strides=strides, + kernel_initializer=kernel_initializer) + res = conv_block(res, filters, kernel_size=kernel_size, padding=padding, strides=1, + kernel_initializer=kernel_initializer) + # Identity mapping - shortcut = Conv3D(filters, kernel_size=(1, 1, 1), padding="same", strides=strides, kernel_initializer=kernel_initializer)(x) + shortcut = Conv3D(filters, kernel_size=(1, 1, 1), padding="same", strides=strides, + kernel_initializer=kernel_initializer)(x) shortcut = norm_act(shortcut, act=False) - - output = Add()([shortcut, res]) + + output = Add()([shortcut, res]) if dropout_type == 'spatial': output = SpatialDropout3D(dropout)(output) elif dropout_type == 'standard': output = Dropout(dropout)(output) - + return output -def upsample_concat_block(x, xskip, filters, kernel_initializer=None, upsample_mode='deconv', padding='valid', use_attention_gate=False): +def upsample_concat_block(x, + xskip, + filters, + kernel_initializer='he_normal', + upsample_mode='deconv', + padding='valid', + use_attention_gate=False): + """ + Create an upsample and concatenate block for U-Net-like architectures. + + Args: + x (tf.Tensor): Input tensor to be upsampled and concatenated. + xskip (tf.Tensor): Skip connection tensor to be concatenated with the upsampled tensor. + filters (int): Number of filters in the convolutional layer. + kernel_initializer: Initializer for the convolutional kernel weights. + upsample_mode (str): The mode for upsampling, either 'deconv' (deconvolution) or 'upsample' (using UpSampling3D). + padding (str): Padding mode for upsampling, either 'valid' or 'same'. + use_attention_gate (bool): Whether to use an attention gate before concatenation. + + Returns: + (tf.Tensor): Output tensor after performing upsample and concatenate operations. + """ if upsample_mode == 'deconv': if padding == 'valid': x = ReflectionPadding3D()(x) - x = Conv3DTranspose(filters, (2, 2, 2), strides=(2, 2, 2), padding='valid')(x) + x = Conv3DTranspose(filters, (2, 2, 2), + strides=(2, 2, 2), + padding='valid', + kernel_initializer=kernel_initializer)(x) else: x = UpSampling3D(size=2)(x) - #x = Conv3D(int(filters/2), (5, 5, 5), strides=(1, 1, 1), padding='valid')(x) + # x = Conv3D(int(filters/2), (5, 5, 5), strides=(1, 1, 1), padding='valid')(x) if use_attention_gate: x = attention_concat(conv_below=x, skip_connection=xskip) else: x = concatenate([x, xskip]) return x + def ResUNet( input_shape, - num_classes=1, - activation='relu', - use_batch_norm=True, upsample_mode='deconv', # 'deconv' or 'simple' dropout=0.2, dropout_change_per_layer=0.0, dropout_type='none', - use_dropout_on_upsampling=False, kernel_initializer='he_normal', - gamma_initializer='he_normal', use_attention_gate=False, filters=16, num_layers=4, output_activation='tanh', use_input_noise=False - ): - - f = [filters, filters*2, filters*4, filters*8, filters*16] +): + """ + Create a Residual U-Net model for 3D image segmentation. + + Args: + input_shape (tuple): Shape of the input image tensor. + upsample_mode (str): The mode for upsampling, either 'deconv' (deconvolution) or 'simple' (using UpSampling3D). 
+ dropout (float): Dropout rate for the initial layer. + dropout_change_per_layer (float): Dropout rate change per layer (optional). + dropout_type (str): Type of dropout, either 'none', 'spatial', or 'channel'. + kernel_initializer (str): Initializer for the convolutional kernel weights. + use_attention_gate (bool): Whether to use an attention gate for concatenation. + filters (int): Number of filters in the convolutional layers. + num_layers (int): Number of layers in the U-Net model. + output_activation (str): Activation function for the output layer. + use_input_noise (bool): Whether to apply Gaussian noise to the input. + + Returns: + (tf.keras.Model): The Residual U-Net model for 3D image segmentation. + """ + f = [filters, filters * 2, filters * 4, filters * 8, filters * 16] inputs = Input(input_shape) skip_layers = [] - + x = inputs - + if use_input_noise: x = GaussianNoise(0.2)(x) - - e = stem(x, f[0]) + + x = stem(x, f[0]) skip_layers.append(x) - + # Encoder - for e in range(1,num_layers+1): - x = residual_block(x, f[e], strides=2, kernel_initializer=kernel_initializer, dropout_type=dropout_type, dropout=dropout+(e-1)*dropout_change_per_layer) + for e in range(1, num_layers + 1): + x = residual_block(x, f[e], strides=2, kernel_initializer=kernel_initializer, dropout_type=dropout_type, + dropout=dropout + (e - 1) * dropout_change_per_layer) skip_layers.append(x) - + # Bridge x = conv_block(x, f[-1], strides=1, kernel_initializer=kernel_initializer) # #d = spatial_attention(d) x = conv_block(x, f[-1], strides=1, kernel_initializer=kernel_initializer) - + for d in reversed(range(num_layers)): - x = upsample_concat_block(x, skip_layers[d], f[d+1], kernel_initializer=kernel_initializer, upsample_mode=upsample_mode, use_attention_gate=use_attention_gate) + x = upsample_concat_block(x, skip_layers[d], f[d + 1], kernel_initializer=kernel_initializer, + upsample_mode=upsample_mode, use_attention_gate=use_attention_gate) x = residual_block(x, f[d], kernel_initializer=kernel_initializer) - + outputs = Conv3D(1, (1, 1, 1), padding="same", activation=output_activation)(x) - + model = Model(inputs, outputs) model.summary() - return model \ No newline at end of file + return model diff --git a/tb_callback.py b/tb_callback.py index 0408c56..75ffe9a 100644 --- a/tb_callback.py +++ b/tb_callback.py @@ -2,22 +2,25 @@ import io import platform import matplotlib + if platform.system() == 'Darwin': - matplotlib.use('TkAgg') + matplotlib.use('TkAgg') import matplotlib.pyplot as plt import tensorflow as tf import numpy as np import typing as t -class TB_Summary(): + +class TB_Summary: """ Helper class to write TensorBoard summaries """ + def __init__(self, output_dir: str): self.dpi = 120 plt.style.use('seaborn-deep') - + self.train_summary_writer = tf.summary.create_file_writer(os.path.join(output_dir, 'train')) self.validate_summary_writer = tf.summary.create_file_writer(os.path.join(output_dir, 'validate')) - + def scalar(self, tag, value, epoch, training): if training: with self.train_summary_writer.as_default(): @@ -25,18 +28,18 @@ def scalar(self, tag, value, epoch, training): else: with self.validate_summary_writer.as_default(): tf.summary.scalar(tag, value, step=epoch) - + def losses(self, results): for key, value in results.items(): value = tf.math.reduce_mean(value) - print('%s = %.4f ' %(key, value.numpy()), end='') + print('%s = %.4f ' % (key, value.numpy()), end='') print('\n') def image(self, tag, values, step: int = 0, training: bool = False): writer = self.get_writer(training) with 
writer.as_default(): - tf.summary.image(tag, data=values, step=step, max_outputs=len(values)) - + tf.summary.image(tag, data=values, step=step, max_outputs=len(values)) + def figure(self, tag, figure, @@ -58,8 +61,8 @@ def figure(self, image = tf.image.decode_png(buffer.getvalue(), channels=4) self.image(tag, tf.expand_dims(image, 0), step=step, training=training) if close: - plt.close(figure) - + plt.close(figure) + def image_cycle(self, tag: str, images: t.List[np.ndarray], @@ -68,7 +71,7 @@ def image_cycle(self, training: bool = False): """ Plot image cycle to TensorBoard Args: - tags: data identifier + tag: data identifier images: list of np.ndarray where len(images) == 3 and each array has shape (N,H,W,C) labels: list of string where len(labels) == 3 @@ -77,24 +80,24 @@ def image_cycle(self, """ assert len(images) == len(labels) == 3 for sample in range(len(images[0])): - figure, axes = plt.subplots(nrows=1, - ncols=3, - figsize=(9, 3.25), - dpi=self.dpi) - axes[0].imshow(images[0][sample, ...], interpolation='none') - axes[0].set_title(labels[0]) - - axes[1].imshow(images[1][sample, ...], interpolation='none') - axes[1].set_title(labels[1]) - - axes[2].imshow(images[2][sample, ...], interpolation='none') - axes[2].set_title(labels[2]) - - plt.setp(axes, xticks=[], yticks=[]) - plt.tight_layout() - figure.subplots_adjust(wspace=0.02, hspace=0.02) - self.figure(tag=f'{tag}/sample_#{sample:03d}', - figure=figure, - step=step, - training=training, - close=True) \ No newline at end of file + figure, axes = plt.subplots(nrows=1, + ncols=3, + figsize=(9, 3.25), + dpi=self.dpi) + axes[0].imshow(images[0][sample, ...], interpolation='none') + axes[0].set_title(labels[0]) + + axes[1].imshow(images[1][sample, ...], interpolation='none') + axes[1].set_title(labels[1]) + + axes[2].imshow(images[2][sample, ...], interpolation='none') + axes[2].set_title(labels[2]) + + plt.setp(axes, xticks=[], yticks=[]) + plt.tight_layout() + figure.subplots_adjust(wspace=0.02, hspace=0.02) + self.figure(tag=f'{tag}/sample_#{sample:03d}', + figure=figure, + step=step, + training=training, + close=True) diff --git a/utils.py b/utils.py index 69195a8..88698bb 100644 --- a/utils.py +++ b/utils.py @@ -5,6 +5,7 @@ import skimage.io as sk from skimage import exposure + def min_max_norm(data): """ Perform min-max normalisation on a N-dimensional numpy array. @@ -21,7 +22,8 @@ def min_max_norm(data): raise ValueError("Cannot perform min-max normalization when max and min are equal.") return (data - dmin) / (dmax - dmin) -def min_max_norm_tf(arr, axis = None): + +def min_max_norm_tf(arr, axis=None): """ Performs min-max normalization on a given array using TensorFlow library. @@ -32,7 +34,7 @@ def min_max_norm_tf(arr, axis = None): Returns: - tensor: A normalized tensor with the same shape as the input array. """ - + if axis is None: # Normalize entire array min_val = tf.reduce_min(arr) @@ -43,10 +45,11 @@ def min_max_norm_tf(arr, axis = None): min_val = tf.reduce_min(arr, axis=axis, keepdims=True) max_val = tf.reduce_max(arr, axis=axis, keepdims=True) tensor = (arr - min_val) / (max_val - min_val) - + return tensor -def rescale_arr_tf(arr, alpha = -0.5, beta = 0.5): + +def rescale_arr_tf(arr, alpha=-0.5, beta=0.5): """ Rescales the values in a tensor using the alpha and beta parameters. 
alpha = -0.5, beta = 0.5: [0,1] to [-1,1] @@ -62,6 +65,7 @@ def rescale_arr_tf(arr, alpha = -0.5, beta = 0.5): """ return tf.math.divide_no_nan((arr + alpha), beta) + def z_score_norm(data): """ Perform z-score normalisation on a one-dimensional numpy array. @@ -77,7 +81,7 @@ def z_score_norm(data): return (data - np.mean(data)) / dstd else: raise ValueError("Cannot perform z-score normalization when the standard deviation is zero.") - + def check_nan(arr): """ @@ -91,6 +95,7 @@ def check_nan(arr): """ return np.any(np.isnan(arr)) + def replace_nan(arr): """ Replace NaN (Not a Number) values in a NumPy array with zeros. @@ -103,7 +108,35 @@ def replace_nan(arr): """ return tf.where(tf.math.is_nan(arr), tf.zeros_like(arr), arr) -def load_volume(file, size=(600,600,700), datatype='uint8', normalise=True): + +def binarise_tensor(arr): + """ + Binarise a TensorFlow tensor by replacing positive values with ones and non-positive values with negative ones. + + Args: + arr (tf.Tensor): Input TensorFlow tensor to be binarised. + + Returns: + (tf.Tensor): Binarized TensorFlow tensor with ones for positive values and negative ones for non-positive values. + """ + return tf.where(tf.math.greater_equal(arr, tf.zeros(tf.shape(arr))), + tf.ones(tf.shape(arr)), + tf.math.negative(tf.ones(tf.shape(arr)))) + +def add_gauss_noise(self, img, rate): + """ + Add Gaussian noise to a TensorFlow image tensor. + + Args: + img (tf.Tensor): Input TensorFlow image tensor to which noise will be added. + rate (float): Standard deviation of the Gaussian noise. + + Returns: + (tf.Tensor): TensorFlow image tensor with added Gaussian noise and values clipped between -1.0 and 1.0. + """ + return tf.clip_by_value(img + tf.random.normal(tf.shape(img), 0.0, rate), -1., 1.) + +def load_volume(file, size=(600, 600, 700), datatype='uint8', normalise=True): """ Load a volume from a (for example) tif file and normalise it. @@ -122,6 +155,7 @@ def load_volume(file, size=(600,600,700), datatype='uint8', normalise=True): vol = min_max_norm(vol) return vol + def resize_volume(img, target_size=None): """ Resize a 3D volume to a target size. @@ -133,28 +167,29 @@ def resize_volume(img, target_size=None): Returns: numpy.ndarray: The resized 3D volume. 
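+
+    Example (illustrative shapes only):
+        vol = np.zeros((600, 600, 700), dtype='float32')
+        out = resize_volume(vol, target_size=(512, 512, 128))  # out.shape == (512, 512, 128)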
""" - + # Create two arrays to hold intermediate and final results arr1 = np.empty([target_size[0], target_size[1], img.shape[2]], dtype='float32') arr2 = np.empty([target_size[0], target_size[1], target_size[2]], dtype='float32') - + # If the input volume's width and height don't match the target size, resize each slice along the z-axis if not img.shape[0:2] == target_size[0:2]: for i in range(img.shape[2]): - arr1[:,:,i] = cv2.resize(img[:,:,i], (target_size[0], target_size[1]), - interpolation=cv2.INTER_LANCZOS4) - + arr1[:, :, i] = cv2.resize(img[:, :, i], (target_size[0], target_size[1]), + interpolation=cv2.INTER_LANCZOS4) + for i in range(target_size[0]): - arr2[i,:,:] = cv2.resize(arr1[i,], (target_size[2], target_size[1]), - interpolation=cv2.INTER_LANCZOS4) - - else: # If the input volume's width and height match the target size, resize each slice along the x-axis + arr2[i, :, :] = cv2.resize(arr1[i,], (target_size[2], target_size[1]), + interpolation=cv2.INTER_LANCZOS4) + + else: # If the input volume's width and height match the target size, resize each slice along the x-axis for i in range(target_size[0]): - arr2[i,:,:] = cv2.resize(img[i,], (target_size[2], target_size[1]), - interpolation=cv2.INTER_LANCZOS4) - + arr2[i, :, :] = cv2.resize(img[i,], (target_size[2], target_size[1]), + interpolation=cv2.INTER_LANCZOS4) + return arr2 + def get_vaccuum(arr, dim): """ Returns the smallest subarray containing all non-zero elements in the input array along the specified dimension(s). @@ -168,11 +203,12 @@ def get_vaccuum(arr, dim): """ if dim == 2: x, y = np.nonzero(arr) - return arr[x.min():x.max()+1, y.min():y.max()+1] + return arr[x.min():x.max() + 1, y.min():y.max() + 1] else: x, y, z = np.nonzero(arr) - return arr[x.min():x.max()+1, y.min():y.max()+1, z.min():z.max()+1] - + return arr[x.min():x.max() + 1, y.min():y.max() + 1, z.min():z.max() + 1] + + def hist_equalization(img): """ Applies histogram equalization to the input image. @@ -185,7 +221,8 @@ def hist_equalization(img): """ img_cdf, bin_centers = exposure.cumulative_distribution(img) return np.interp(img, bin_centers, img_cdf) - + + def save_dict(di_, filename_): """Saves a Python dictionary object to a file using the pickle module. @@ -199,6 +236,7 @@ def save_dict(di_, filename_): with open(filename_, 'wb') as f: pickle.dump(di_, f) + def load_dict(filename_): """ Load a dictionary from a binary file using the pickle module. @@ -212,7 +250,8 @@ def load_dict(filename_): ret_di = pickle.load(f) return ret_di -def append_dict(dict1, dict2, replace = False) -> dict: + +def append_dict(dict1, dict2, replace=False) -> dict: """ Append items in dict2 to dict1. @@ -259,10 +298,10 @@ def get_sub_volume(image, subvol=(64, 64, 512), n_samples=1): Returns: - subvol (numpy.ndarray): A numpy array of shape (subvol[0], subvol[1], subvol[2], subvol[3]) representing the sub-volume extracted from the input image tensor. 
""" - + # Initialize features and labels with `None` sample = np.empty([subvol[0], subvol[1], subvol[2], subvol[3]], dtype='float32') - + # randomly sample sub-volume by sampling the corner voxel start_x = np.random.randint(image.shape[0] - subvol[0] + 1) start_y = np.random.randint(image.shape[1] - subvol[1] + 1) @@ -270,11 +309,12 @@ def get_sub_volume(image, subvol=(64, 64, 512), n_samples=1): # make copy of the sub-volume sample = np.copy(image[start_x: start_x + subvol[0], - start_y: start_y + subvol[1], - start_z: start_z + subvol[2], :]) - + start_y: start_y + subvol[1], + start_z: start_z + subvol[2], :]) + return sample + def get_shape(arr): """ Get the shape of a nested list. @@ -300,4 +340,4 @@ def get_shape(arr): # @tf.function(experimental_compile=True) # Enable XLA # def fast_clahe(img): -# return tf_clahe.clahe(img, tile_grid_size=(4, 4), gpu_optimized=True) \ No newline at end of file +# return tf_clahe.clahe(img, tile_grid_size=(4, 4), gpu_optimized=True) diff --git a/vangan.py b/vangan.py index f3caaa4..099a5cd 100644 --- a/vangan.py +++ b/vangan.py @@ -4,30 +4,31 @@ from tqdm import tqdm from generator import get_resnet_generator from discriminator import get_discriminator -from loss_functions import (generator_loss_fn, - discriminator_loss_fn, - cycle_loss, - identity_loss, - cycle_seg_loss, - wasserstein_generator_loss, - wasserstein_discriminator_loss, +from loss_functions import (generator_loss_fn, + discriminator_loss_fn, + cycle_loss, + identity_loss, + cycle_seg_loss, + wasserstein_generator_loss, + wasserstein_discriminator_loss, reduce_mean, - cycle_perceptual) + cycle_reconstruction) from vnet_model import custom_vnet from resunet_model import ResUNet -class VanGan(): + +class VanGan: def __init__( - self, - args, - strategy, - lambda_cycle=10.0, - lambda_identity=5, - genAB_typ = 'resnet', - genBA_typ = 'resnet', - wasserstein = False, - ncritic = 5, - gp_weight=10.0 + self, + args, + strategy, + lambda_cycle=10.0, + lambda_identity=5, + gen_i2s='resnet', + gen_s2i='resnet', + wasserstein=False, + ncritic=5, + gp_weight=10.0 ): self.strategy = strategy self.n_devices = args.N_DEVICES @@ -35,24 +36,24 @@ def __init__( self.lambda_cycle = lambda_cycle self.lambda_identity = lambda_identity, self.channels = args.CHANNELS - self.genAB_typ = genAB_typ - self.genBA_typ = genBA_typ + self.gen_i2s_typ = gen_i2s + self.gen_s2i_typ = gen_s2i self.global_batch_size = args.GLOBAL_BATCH_SIZE self.dims = args.DIMENSIONS if self.dims == 2: self.subvol_patch_size = (args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], self.channels) self.seg_subvol_patch_size = (args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], 1) else: - self.subvol_patch_size = (args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], args.SUBVOL_PATCH_SIZE[2], self.channels) - self.seg_subvol_patch_size = (args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], args.SUBVOL_PATCH_SIZE[2], 1) - self.gamma_init = args.GAMMA_INIT - self.kernel_init = args.KERNEL_INIT + self.subvol_patch_size = ( + args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], args.SUBVOL_PATCH_SIZE[2], self.channels) + self.seg_subvol_patch_size = ( + args.SUBVOL_PATCH_SIZE[0], args.SUBVOL_PATCH_SIZE[1], args.SUBVOL_PATCH_SIZE[2], 1) self.train_steps = args.train_steps self.batch_size = args.BATCH_SIZE self.cycle_loss_fn = cycle_loss self.identity_loss_fn = identity_loss self.wasserstein = wasserstein - self.ncritic = ncritic + self.ncritic = ncritic self.icritic = 1 self.initModel = True self.updateGen = True @@ -63,232 +64,206 @@ def 
__init__( self.discriminator_loss_fn = discriminator_loss_fn self.identity_loss_fn = identity_loss self.seg_loss_fn = cycle_seg_loss - self.perceptual_loss = cycle_perceptual + self.reconstruction_loss = cycle_reconstruction self.decayed_noise_rate = 0.5 self.current_epoch = 0 self.layer_noise = 0.1 self.checkpoint_loaded = False - + # create checkpoint directory self.checkpoint_dir = os.path.join(args.output_dir, 'checkpoints') if not os.path.exists(self.checkpoint_dir): os.makedirs(self.checkpoint_dir) self.checkpoint_prefix = os.path.join(self.checkpoint_dir, 'checkpoint') - + # Initialize generator & discriminator with self.strategy.scope(): - - if self.genAB_typ == 'resnet': - self.gen_AB = get_resnet_generator( - input_img_size=self.subvol_patch_size, - batch_size=self.global_batch_size, - # gamma_initializer=self.gamma_init, - # kernel_initializer=self.kernel_init, - name='generator_AB', - num_downsampling_blocks=3, - num_upsample_blocks=3 - ) - elif self.genAB_typ == 'vnet': - self.gen_AB = custom_vnet( - input_shape = self.subvol_patch_size, - num_classes=1, - activation='relu', - use_batch_norm=False, - upsample_mode='upsample', - dropout=0.5, - dropout_change_per_layer=0.0, - dropout_type='spatial', - use_dropout_on_upsampling=False, - kernel_initializer=self.kernel_init, - use_attention_gate=False, - filters=32, - num_layers=4, - output_activation='tanh', - ) - elif self.genAB_typ == 'resUnet': - self.gen_AB = ResUNet( - input_shape = self.subvol_patch_size, - num_classes=1, - activation='relu', - use_batch_norm=False, - upsample_mode='simple', - dropout=0.1, - dropout_change_per_layer=0.1, - dropout_type='none', - use_dropout_on_upsampling=False, - kernel_initializer=self.kernel_init, - use_attention_gate=False, - filters=16, - num_layers=4, - # output_activation=None, - ) + + if self.gen_i2s_typ == 'resnet': + self.gen_IS = get_resnet_generator( + input_img_size=self.subvol_patch_size, + batch_size=self.global_batch_size, + name='generator_IS', + num_downsampling_blocks=3, + num_upsample_blocks=3 + ) + elif self.gen_i2s_typ == 'vnet': + self.gen_IS = custom_vnet( + input_shape=self.subvol_patch_size, + activation='relu', + use_batch_norm=False, + upsample_mode='upsample', + dropout=0.5, + dropout_change_per_layer=0.0, + dropout_type='spatial', + use_dropout_on_upsampling=False, + use_attention_gate=False, + filters=32, + num_layers=4, + output_activation='tanh', + ) + elif self.gen_i2s_typ == 'resUnet': + self.gen_IS = ResUNet( + input_shape=self.subvol_patch_size, + upsample_mode='simple', + dropout=0.1, + dropout_change_per_layer=0.1, + dropout_type='none', + use_attention_gate=False, + filters=16, + num_layers=4, + # output_activation=None, + ) else: raise ValueError('AB Generator type not recognised') - - if self.genBA_typ == 'resnet': - self.gen_BA = get_resnet_generator( - input_img_size=self.subvol_patch_size, - batch_size=self.global_batch_size, - # gamma_initializer=self.gamma_init, - #kernel_initializer=self.kernel_init, - name='generator_BA', - num_downsampling_blocks=3, - num_upsample_blocks=3 - ) - elif self.genBA_typ == 'vnet': - self.gen_BA = custom_vnet( - input_shape = self.subvol_patch_size, - num_classes=1, - activation='relu', - use_batch_norm=True, - upsample_mode='deconv', - dropout=0.5, - dropout_change_per_layer=0.0, - dropout_type='spatial', - use_dropout_on_upsampling=False, - kernel_initializer=self.kernel_init, - use_attention_gate=False, - filters=16, - num_layers=4, - output_activation='tanh', - addnoise=False - ) - elif self.genBA_typ == 
'resUnet': - self.gen_BA = ResUNet( - input_shape = self.seg_subvol_patch_size, - num_classes=1, - activation='relu', - use_batch_norm=False, - upsample_mode='simple', - dropout=0.1, - dropout_change_per_layer=0.1, - dropout_type='none', - use_dropout_on_upsampling=False, - kernel_initializer=self.kernel_init, - use_attention_gate=False, - filters=16, - num_layers=4, - # output_activation=None, - use_input_noise=False - ) + + if self.gen_s2i_typ == 'resnet': + self.gen_SI = get_resnet_generator( + input_img_size=self.subvol_patch_size, + batch_size=self.global_batch_size, + name='generator_SI', + num_downsampling_blocks=3, + num_upsample_blocks=3 + ) + elif self.gen_s2i_typ == 'vnet': + self.gen_SI = custom_vnet( + input_shape=self.subvol_patch_size, + activation='relu', + use_batch_norm=True, + upsample_mode='deconv', + dropout=0.5, + dropout_change_per_layer=0.0, + dropout_type='spatial', + use_dropout_on_upsampling=False, + use_attention_gate=False, + filters=16, + num_layers=4, + output_activation='tanh', + addnoise=False + ) + elif self.gen_s2i_typ == 'resUnet': + self.gen_SI = ResUNet( + input_shape=self.seg_subvol_patch_size, + upsample_mode='simple', + dropout=0.1, + dropout_change_per_layer=0.1, + dropout_type='none', + use_attention_gate=False, + filters=16, + num_layers=4, + # output_activation=None, + use_input_noise=False + ) else: raise ValueError('BA Generator type not recognised') - - + # Get the discriminators - self.disc_A = get_discriminator( - input_img_size=self.subvol_patch_size, - batch_size=self.global_batch_size, - kernel_initializer=self.kernel_init, - name='discriminator_A', - filters=64, - use_dropout=False, - wasserstein=self.wasserstein, - use_SN=False, - use_input_noise=True, - use_layer_noise=True, - noise_std=self.layer_noise - ) - self.disc_B = get_discriminator( - input_img_size=self.seg_subvol_patch_size, - batch_size=self.global_batch_size, - kernel_initializer=self.kernel_init, - name='discriminator_B', - filters=64, - use_dropout=False, - wasserstein=self.wasserstein, - use_SN=False, - use_input_noise=True, - use_layer_noise=True, - noise_std=self.layer_noise - ) - - + self.disc_I = get_discriminator( + input_img_size=self.subvol_patch_size, + batch_size=self.global_batch_size, + name='discriminator_I', + filters=64, + use_dropout=False, + wasserstein=self.wasserstein, + use_SN=False, + use_input_noise=True, + use_layer_noise=True, + noise_std=self.layer_noise + ) + self.disc_S = get_discriminator( + input_img_size=self.seg_subvol_patch_size, + batch_size=self.global_batch_size, + name='discriminator_S', + filters=64, + use_dropout=False, + wasserstein=self.wasserstein, + use_SN=False, + use_input_noise=True, + use_layer_noise=True, + noise_std=self.layer_noise + ) + # Initialise optimizers if self.wasserstein: - - self.gen_A_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0., beta_2=0.9)#, clipnorm=10.0) - self.gen_B_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0., beta_2=0.9)#, clipnorm=10.0) - self.disc_A_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0., beta_2=0.9)#, clipnorm=10.0) - self.disc_B_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0., beta_2=0.9)#, clipnorm=10.0) - + + self.gen_I_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0., + beta_2=0.9) # , clipnorm=10.0) + self.gen_S_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0., + beta_2=0.9) # , clipnorm=10.0) + self.disc_I_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, 
beta_1=0., + beta_2=0.9) # , clipnorm=10.0) + self.disc_S_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0., + beta_2=0.9) # , clipnorm=10.0) + else: # Initialise decay rates - self.dA_lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( - 2e-4, - decay_steps=5*self.train_steps, - decay_rate=0.98, - staircase=False) - - self.dB_lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( - 2e-4, - decay_steps=5*self.train_steps, - decay_rate=0.98, - staircase=False) - - self.gen_A_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, - beta_1=0.5, - beta_2=0.9, - clipnorm=100) - self.gen_B_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, - beta_1=0.5, - beta_2=0.9, - clipnorm=100) - self.disc_A_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, - beta_1=0.5, - beta_2=0.9, - clipnorm=100) - self.disc_B_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, - beta_1=0.5, - beta_2=0.9, - clipnorm=100) - + self.dI_lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( + 2e-4, + decay_steps=5 * self.train_steps, + decay_rate=0.98, + staircase=False) + + self.dS_lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( + 2e-4, + decay_steps=5 * self.train_steps, + decay_rate=0.98, + staircase=False) + + self.gen_I_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, + beta_1=0.5, + beta_2=0.9, + clipnorm=100) + self.gen_S_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, + beta_1=0.5, + beta_2=0.9, + clipnorm=100) + self.disc_I_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, + beta_1=0.5, + beta_2=0.9, + clipnorm=100) + self.disc_S_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, + beta_1=0.5, + beta_2=0.9, + clipnorm=100) + # Initialise checkpoint - self.checkpoint = tf.train.Checkpoint(gen_AB=self.gen_AB, - gen_BAF=self.gen_BA, - disc_A=self.disc_A, - disc_B=self.disc_B, - gen_A_optimizer=self.gen_A_optimizer, - gen_B_optimizer=self.gen_B_optimizer, - disc_A_optimizer=self.disc_A_optimizer, - disc_B_optimizer=self.disc_B_optimizer) - + self.checkpoint = tf.train.Checkpoint(gen_AB=self.gen_IS, + gen_BAF=self.gen_SI, + disc_A=self.disc_I, + disc_B=self.disc_S, + gen_A_optimizer=self.gen_I_optimizer, + gen_B_optimizer=self.gen_S_optimizer, + disc_A_optimizer=self.disc_I_optimizer, + disc_B_optimizer=self.disc_S_optimizer) + def save_checkpoint(self, epoch): """ save checkpoint to checkpoint_dir, overwrite if exists """ - self.checkpoint.write(self.checkpoint_prefix+"_e{epoch}".format(epoch=epoch+1)) + self.checkpoint.write(self.checkpoint_prefix + "_e{epoch}".format(epoch=epoch + 1)) print(f'\nSaved checkpoint to {self.checkpoint_prefix}\n') def load_checkpoint(self, epoch=None, expect_partial: bool = False, newpath=None): """ load checkpoint from checkpoint_dir if exists """ if newpath is not None: self.checkpoint_prefix = os.path.join(newpath, 'checkpoint') - checkpoint_path = self.checkpoint_prefix +"_e{epoch}".format(epoch=epoch) + checkpoint_path = self.checkpoint_prefix + "_e{epoch}".format(epoch=epoch) if os.path.exists(f'{os.path.join(checkpoint_path)}.index'): self.checkpoint_loaded = True with self.strategy.scope(): - if expect_partial: - self.checkpoint.read(checkpoint_path).expect_partial() - else: - self.checkpoint.read(checkpoint_path) + if expect_partial: + self.checkpoint.read(checkpoint_path).expect_partial() + else: + self.checkpoint.read(checkpoint_path) print(f'\nLoaded checkpoint from {checkpoint_path}\n') else: print('Error: Checkpoint not found!') - - def addNoise(self, img, rate): - return 
tf.clip_by_value(img + tf.random.normal(tf.shape(img), 0.0, rate), -1., 1.) - - def binarise_tensor(self, arr): - return tf.where(tf.math.greater_equal(arr, tf.zeros(tf.shape(arr))), - tf.ones(tf.shape(arr)), - tf.math.negative(tf.ones(tf.shape(arr)))) - - def computeLosses(self, real_A, real_B, result, training=True): + + def compute_losses(self, real_I, real_S, result, training=True): """ Computes the losses for the VANGAN model using the given input images and model settings. Args: - real_A (tf.Tensor): A tensor containing the real images from domain A. - real_B (tf.Tensor): A tensor containing the real images from domain B. + real_I (tf.Tensor): A tensor containing the real images from the imaging domain. + real_S (tf.Tensor): A tensor containing the real images from the segmentation domain. result (dict): A dictionary to store the loss values. training (bool, optional): A flag indicating whether the model is being trained or not. Defaults to True. @@ -296,88 +271,85 @@ def computeLosses(self, real_A, real_B, result, training=True): tuple: A tuple containing the updated result dictionary and the calculated losses. Raises: - ValueError: If the `cycle_loss_fn`, `seg_loss_fn`, `perceptual_loss`, `discriminator_loss_fn`, + ValueError: If the `cycle_loss_fn`, `seg_loss_fn`, `reconstruction_loss`, `discriminator_loss_fn`, `generator_loss_fn`, `wasserstein_discriminator_loss` or `wasserstein_generator_loss` are not callable functions. """ - + # Can be used to debug dataset numerics - #tf.debugging.check_numerics(real_A, 'real_A failure') - #tf.debugging.check_numerics(real_B, 'real_B failure') - + # tf.debugging.check_numerics(real_I, 'real_I failure') + # tf.debugging.check_numerics(real_S, 'real_S failure') + # A -> B - fake_B = self.gen_AB(real_A, training=training) + fake_S = self.gen_IS(real_I, training=training) # B -> A - fake_A = self.gen_BA(real_B, training=training) - + fake_I = self.gen_SI(real_S, training=training) + # Cycle loss - cycled_B = self.gen_AB(fake_A, training=training) - - cycle_loss_A = self.cycle_loss_fn(self, real_B, cycled_B, typ="bce") - - - seg_loss = self.seg_loss_fn(self, real_B, cycled_B) - cycled_A = self.gen_BA(fake_B, training=training) - cycle_loss_B = self.cycle_loss_fn(self, real_A, cycled_A, typ='L2') - - perceptualA_loss = self.perceptual_loss(self, real_A, cycled_A) - + cycled_S = self.gen_IS(fake_I, training=training) + + cycle_loss_I = self.cycle_loss_fn(self, real_S, cycled_S, typ="bce") + + seg_loss = self.seg_loss_fn(self, real_S, cycled_S) + cycled_I = self.gen_SI(fake_S, training=training) + cycle_loss_S = self.cycle_loss_fn(self, real_I, cycled_I, typ='L2') + + reconstruction_loss_I = self.reconstruction_loss(self, real_I, cycled_I) + # Identity mapping - # id_BA_loss = self.identity_loss_fn(self, real_A, self.gen_BA(real_A, training=True)) - # id_AB_loss = self.identity_loss_fn(self, real_B, self.gen_AB(real_B, training=True), typ='cldice') - - + # id_SI_loss = self.identity_loss_fn(self, real_I, self.gen_SI(real_I, training=True)) + # id_IS_loss = self.identity_loss_fn(self, real_S, self.gen_IS(real_S, training=True), typ='cldice') + # Discriminator outputs - disc_real_B = self.disc_B(real_B, training=training) - disc_fake_B = self.disc_B(fake_B, training=training) - - disc_real_A = self.disc_A(real_A, training=training) - disc_fake_A = self.disc_A(fake_A, training=training) - + disc_real_S = self.disc_S(real_S, training=training) + disc_fake_S = self.disc_S(fake_S, training=training) + + disc_real_I = self.disc_I(real_I, training=training) 
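+        # The discriminator outputs here are raw logits; the non-Wasserstein loss
+        # functions below are therefore evaluated with from_logits=True.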
+        disc_fake_I = self.disc_I(fake_I, training=training)

        # Generator & discriminator loss
        if self.wasserstein:
-            gen_AB_loss = self.wasserstein_generator_loss(self, disc_fake_B)
-            gen_BA_loss = self.wasserstein_generator_loss(self, disc_fake_A)
-            disc_A_loss = self.wasserstein_discriminator_loss(self, disc_real_A, disc_fake_A)
-            disc_B_loss = self.wasserstein_discriminator_loss(self, disc_real_B, disc_fake_B)
-
+            gen_IS_loss = self.wasserstein_generator_loss(self, disc_fake_S)
+            gen_SI_loss = self.wasserstein_generator_loss(self, disc_fake_I)
+            disc_I_loss = self.wasserstein_discriminator_loss(self, disc_real_I, disc_fake_I)
+            disc_S_loss = self.wasserstein_discriminator_loss(self, disc_real_S, disc_fake_S)
+
        else:
-            gen_AB_loss = self.generator_loss_fn(self, disc_fake_B, from_logits=True)
-            gen_BA_loss = self.generator_loss_fn(self, disc_fake_A, from_logits=True)
-            disc_A_loss = self.discriminator_loss_fn(self, disc_real_A, disc_fake_A, from_logits=True)
-            disc_B_loss = self.discriminator_loss_fn(self, disc_real_B, disc_fake_B, from_logits=True)
-
+            gen_IS_loss = self.generator_loss_fn(self, disc_fake_S, from_logits=True)
+            gen_SI_loss = self.generator_loss_fn(self, disc_fake_I, from_logits=True)
+            disc_I_loss = self.discriminator_loss_fn(self, disc_real_I, disc_fake_I, from_logits=True)
+            disc_S_loss = self.discriminator_loss_fn(self, disc_real_S, disc_fake_S, from_logits=True)
+
        # Total generator loss
-        total_loss_A = gen_AB_loss + cycle_loss_A + seg_loss #+ id_BA_loss
-        total_loss_B = gen_BA_loss + cycle_loss_B + perceptualA_loss #+ id_AB_loss
-
+        total_loss_I = gen_IS_loss + cycle_loss_I + seg_loss  # + id_SI_loss
+        total_loss_S = gen_SI_loss + cycle_loss_S + reconstruction_loss_I  # + id_IS_loss
+
        result.update({
-            'total_AB_loss': total_loss_A,
-            'total_BA_loss': total_loss_B,
-            'D_A_loss': disc_A_loss,
-            'D_B_loss': disc_B_loss,
-            'gen_AB_loss': gen_AB_loss,
-            'gen_BA_loss': gen_BA_loss,
-            'cycle_gen_AB_loss': cycle_loss_A,
-            'cycle_gen_BA_loss': cycle_loss_B,
+            'total_IS_loss': total_loss_I,
+            'total_SI_loss': total_loss_S,
+            'D_I_loss': disc_I_loss,
+            'D_S_loss': disc_S_loss,
+            'gen_IS_loss': gen_IS_loss,
+            'gen_SI_loss': gen_SI_loss,
+            'cycle_gen_IS_loss': cycle_loss_I,
+            'cycle_gen_SI_loss': cycle_loss_S,
            'seg_loss': seg_loss,
-            'perceptualA_loss': perceptualA_loss,
-            # 'identity_AB': id_AB_loss,
-            # 'identity_BA': id_BA_loss
+            'reconstruction_loss_I': reconstruction_loss_I,
+            # 'identity_IS': id_IS_loss,
+            # 'identity_SI': id_SI_loss
        })
-
-        return result, total_loss_A, total_loss_B, disc_A_loss, disc_B_loss, fake_A, fake_B
+
+        return result, total_loss_I, total_loss_S, disc_I_loss, disc_S_loss, fake_I, fake_S
+
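For orientation, the cycle-consistency pattern that compute_losses follows can be sketched in isolation. The snippet below is illustrative only: the generator arguments stand in for the trained gen_IS/gen_SI models, and the raw L2/BCE penalties stand in for the configurable cycle_loss_fn, which also applies the loss weighting:

    import tensorflow as tf

    def cycle_losses_sketch(gen_IS, gen_SI, real_I, real_S):
        # Imaging cycle I -> S -> I, penalised with L2 (mirrors typ='L2' above).
        fake_S = gen_IS(real_I, training=True)
        cycled_I = gen_SI(fake_S, training=True)
        cycle_loss_S = tf.reduce_mean(tf.square(real_I - cycled_I))

        # Segmentation cycle S -> I -> S, penalised with BCE (mirrors typ='bce'
        # above; assumes segmentation volumes are scaled to [0, 1]).
        fake_I = gen_SI(real_S, training=True)
        cycled_S = gen_IS(fake_I, training=True)
        cycle_loss_I = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(real_S, cycled_S))

        return cycle_loss_I, cycle_loss_S

The variable names deliberately mirror compute_losses, including its cross-domain naming of the two cycle terms.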
-    def gradient_penalty(self, real, fake, descrip='A'):
+    def gradient_penalty(self, real, fake, descrip='I'):
        """
        Computes the gradient penalty for the Wasserstein loss function.

        Parameters:
-        - real: the real input data (either A or B) with dimensions [batch_size, height, width, channels]
-        - fake: the generated data (either A or B) with dimensions [batch_size, height, width, channels]
-        - descrip: specifies which discriminator to use (either 'A' or 'B')
+        - real: the real input data (either I or S) with dimensions [batch_size, height, width, depth, channels]
+        - fake: the generated data (either I or S) with dimensions [batch_size, height, width, depth, channels]
+        - descrip: specifies which discriminator to use (either 'I' or 'S')

        Returns:
        - gp: the computed gradient penalty
        """
        alpha = tf.random.normal([self.batch_size, 1, 1, 1, 1], 0.0, 1.0)
        diff = fake - real
        interpolated = real + alpha * diff
-        if descrip == 'A':
-            pred = self.disc_A(interpolated, training=True)
+        if descrip == 'I':
+            pred = self.disc_I(interpolated, training=True)
        else:
-            pred = self.disc_B(interpolated, training=True)
-        grads = tf.gradients(pred, interpolated)[0]
-        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3, 4]) + 1.e-12) # small constant add to prevent division by zero
+            pred = self.disc_S(interpolated, training=True)
+        grads = tf.gradients(pred, interpolated)[0]
+        norm = tf.sqrt(tf.reduce_sum(tf.square(grads),
+                                     axis=[1, 2, 3, 4]) + 1.e-12)  # small constant added to prevent division by zero
        gp = reduce_mean(self, (norm - 1.0) ** 2)
        return gp
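The penalty above uses tf.gradients, which is valid here because the training steps are wrapped in tf.function and therefore run in graph mode. For comparison, a minimal standalone sketch of the same WGAN-GP term using GradientTape; note that the WGAN-GP paper draws the mixing coefficient from a uniform distribution, whereas the method above samples it from a normal distribution:

    import tensorflow as tf

    def gradient_penalty_sketch(critic, real, fake, batch_size):
        # Random point on the line between each real/fake pair.
        alpha = tf.random.uniform([batch_size, 1, 1, 1, 1], 0.0, 1.0)
        interpolated = real + alpha * (fake - real)

        with tf.GradientTape() as tape:
            tape.watch(interpolated)
            pred = critic(interpolated, training=True)

        grads = tape.gradient(pred, interpolated)
        # Per-sample gradient norm over the spatial and channel axes; the
        # small epsilon guards against a zero norm.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3, 4]) + 1e-12)
        return tf.reduce_mean((norm - 1.0) ** 2)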
-
-    def train_step(self, real_A, real_B):
+
+    def train_step(self, real_I, real_S):
        """
        Trains the VANGAN model using a single batch of input images.

        Parameters:
        - `self`: the VANGAN object.
-        - `real_A`: a batch of images from domain A.
-        - `real_B`: a batch of images from domain B.
+        - `real_I`: a batch of images from the imaging domain.
+        - `real_S`: a batch of images from the segmentation domain.

        Returns:
        - `result`: a dictionary containing the losses and metrics computed during training.
        """
        result = {}
-        with tf.GradientTape(persistent=True) as tape:
-            result, total_loss_A, total_loss_B, disc_A_loss, disc_B_loss, fake_A, fake_B = self.computeLosses(real_A, real_B, result, training=True)
+        with tf.GradientTape(persistent=True) as tape:
+            result, total_loss_I, total_loss_S, disc_I_loss, disc_S_loss, fake_I, fake_S = self.compute_losses(real_I,
+                                                                                                               real_S,
+                                                                                                               result,
+                                                                                                               training=True)

        if self.wasserstein:
-
            if self.updateGen:
-                self.gen_A_optimizer.minimize(loss=total_loss_A,
-                                              var_list=self.gen_AB.trainable_variables,
+                self.gen_I_optimizer.minimize(loss=total_loss_I,
+                                              var_list=self.gen_IS.trainable_variables,
                                              tape=tape)
-                self.gen_B_optimizer.minimize(loss=total_loss_B,
-                                              var_list=self.gen_BA.trainable_variables,
+                self.gen_S_optimizer.minimize(loss=total_loss_S,
+                                              var_list=self.gen_SI.trainable_variables,
                                              tape=tape)
-
-            if self.initModel == False:
-                gp = self.gradient_penalty(real_A, fake_A, descrip='A')
-                disc_A_loss = disc_A_loss + gp * self.gp_weight
-
-                gp = self.gradient_penalty(real_B, fake_B, descrip='B')
-                disc_B_loss = disc_B_loss + gp * self.gp_weight
-
-
+
+            if not self.initModel:
+                gp = self.gradient_penalty(real_I, fake_I, descrip='I')
+                disc_I_loss = disc_I_loss + gp * self.gp_weight
+
+                gp = self.gradient_penalty(real_S, fake_S, descrip='S')
+                disc_S_loss = disc_S_loss + gp * self.gp_weight
+
-            # clipping weights of discriminators as told in the
-            # WasserteinGAN paper to enforce Lipschitz constraint.
+            # Clipping the discriminator weights, as described in the
+            # Wasserstein GAN paper, to enforce the Lipschitz constraint.
            # clip_values = [-0.01, 0.01]
            # self.clip_discriminator_A_var_op = [var.assign(tf.clip_by_value(var, clip_values[0], clip_values[1])) for
-            #                                    var in self.disc_A.trainable_variables]
+            #                                    var in self.disc_I.trainable_variables]
            # self.clip_discriminator_B_var_op = [var.assign(tf.clip_by_value(var, clip_values[0], clip_values[1])) for
-            #                                    var in self.disc_B.trainable_variables]
-
-
+            #                                    var in self.disc_S.trainable_variables]
+
        else:
-            self.gen_A_optimizer.minimize(loss=total_loss_A,
-                                          var_list=self.gen_AB.trainable_variables,
+            self.gen_I_optimizer.minimize(loss=total_loss_I,
+                                          var_list=self.gen_IS.trainable_variables,
                                          tape=tape)
-            self.gen_B_optimizer.minimize(loss=total_loss_B,
-                                          var_list=self.gen_BA.trainable_variables,
+            self.gen_S_optimizer.minimize(loss=total_loss_S,
+                                          var_list=self.gen_SI.trainable_variables,
                                          tape=tape)
-
-
-        self.disc_A_optimizer.minimize(loss=disc_A_loss,
-                                       var_list=self.disc_A.trainable_variables,
-                                       tape=tape)
-        self.disc_B_optimizer.minimize(loss=disc_B_loss,
-                                       var_list=self.disc_B.trainable_variables,
-                                       tape=tape)
-
+
+        self.disc_I_optimizer.minimize(loss=disc_I_loss,
+                                       var_list=self.disc_I.trainable_variables,
+                                       tape=tape)
+        self.disc_S_optimizer.minimize(loss=disc_S_loss,
+                                       var_list=self.disc_S.trainable_variables,
+                                       tape=tape)
+
        return result
-
-    def test_step(self, real_A, real_B):
+
+    def test_step(self, real_I, real_S):
        """
        Evaluates the VANGAN model on a single batch of input images.

        Parameters:
        - `self`: the VANGAN object.
-        - `real_A`: a batch of images from domain A.
-        - `real_B`: a batch of images from domain B.
+        - `real_I`: a batch of images from the imaging domain.
+        - `real_S`: a batch of images from the segmentation domain.

        Returns:
        - `result`: a dictionary containing the losses and metrics computed during evaluation.
        """
        result = {}
-        result, _, _, _, _, _, _ = self.computeLosses(real_A, real_B, result, training=False)
+        result, _, _, _, _, _, _ = self.compute_losses(real_I, real_S, result, training=False)
        return result
-
+
    def reduce_dict(self, d: dict):
        """
        Reduces the values in a dictionary using the current distribution strategy.
@@ -487,8 +460,8 @@ def reduce_dict(self, d: dict):
        ''' reduce items in dictionary d '''
        for k, v in d.items():
-            d[k] = self.strategy.reduce(tf.distribute.ReduceOp.SUM, v, axis=None)
-
+            d[k] = self.strategy.reduce(tf.distribute.ReduceOp.SUM, v, axis=None)
+
    @tf.function
    def distributed_train_step(self, x, y):
        """
@@ -496,8 +469,8 @@ def distributed_train_step(self, x, y):

        Parameters:
        - `self`: the VANGAN object.
-        - `x`: a batch of images from domain A.
-        - `y`: a batch of images from domain B.
+        - `x`: a batch of images from the imaging domain.
+        - `y`: a batch of images from the segmentation domain.

        Returns:
        - `results`: a dictionary containing the losses and metrics computed during training.
        """
        results = self.strategy.run(self.train_step, args=(x, y))
        self.reduce_dict(results)
        return results
-
+
    @tf.function
    def distributed_test_step(self, x, y):
        """
@@ -513,8 +486,8 @@ def distributed_test_step(self, x, y):

        Parameters:
        - `self`: the VANGAN object.
-        - `x`: a batch of images from domain A.
-        - `y`: a batch of images from domain B.
+        - `x`: a batch of images from the imaging domain.
+        - `y`: a batch of images from the segmentation domain.

        Returns:
        - `results`: a dictionary containing the losses and metrics computed during testing.
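These tf.function wrappers assume the usual tf.distribute driving pattern: the VANGAN object is built under a strategy scope and the datasets are distributed before iteration. A minimal sketch of that pattern follows; build_vangan, paired_dataset and num_epochs are illustrative placeholders, not names from this repository:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        gan = build_vangan()  # hypothetical constructor; variables must be created under the scope

    train_ds = strategy.experimental_distribute_dataset(paired_dataset)

    num_epochs = 10
    for epoch in range(num_epochs):
        for x, y in train_ds:
            # Runs train_step once per replica; reduce_dict then sums the
            # per-replica loss values into a single dictionary.
            results = gan.distributed_train_step(x, y)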
@@ -523,7 +496,8 @@ def distributed_test_step(self, x, y): self.reduce_dict(results) return results -def train(args, ds, gan, summary, epoch: int, steps=None, desc=None, training=True): + +def train(ds, gan, summary, epoch: int, steps=None, desc=None, training=True): """ Runs a training or testing loop for a given number of steps using the specified VANGAN object and dataset. @@ -541,7 +515,7 @@ def train(args, ds, gan, summary, epoch: int, steps=None, desc=None, training=Tr - `results`: a dictionary containing the losses and metrics computed during training or testing. """ results = {} - cntr = 0 + cntr = 0 for x, y in tqdm(ds, desc=desc, total=steps, disable=0): if cntr == steps: break @@ -549,7 +523,7 @@ def train(args, ds, gan, summary, epoch: int, steps=None, desc=None, training=Tr cntr += 1 if training: if gan.icritic % gan.ncritic == 0: - gan.updateGen = True + gan.updateGen = True gan.icritic = 1 else: gan.icritic += 1 @@ -562,6 +536,5 @@ def train(args, ds, gan, summary, epoch: int, steps=None, desc=None, training=Tr for key, value in results.items(): summary.scalar(key, tf.reduce_mean(value), epoch=epoch, training=training) - + return results - diff --git a/vnet_model.py b/vnet_model.py index 2f80297..8cbc9a8 100644 --- a/vnet_model.py +++ b/vnet_model.py @@ -12,7 +12,7 @@ multiply, add, Activation, - ) +) import tensorflow_addons as tfa import tensorflow as tf from building_blocks import ReflectionPadding3D @@ -20,10 +20,20 @@ '''https://github.com/karolzak/keras-unet''' + def attention_gate(inp_1, inp_2, n_intermediate_filters): - '''Attention gate. Compresses both inputs to n_intermediate_filters filters before processing. - Implemented as proposed by Oktay et al. in their Attention U-net, see: https://arxiv.org/abs/1804.03999. - ''' + """ + Attention gate. Compresses both inputs to `n_intermediate_filters` filters before processing. + Implemented as proposed by Oktay et al. in their Attention U-Net, see: https://arxiv.org/abs/1804.03999. + + Args: + inp_1 (tf.Tensor): First input tensor to the attention gate. + inp_2 (tf.Tensor): Second input tensor to the attention gate (skip-connection). + n_intermediate_filters (int): Number of intermediate filters to use in the attention gate. + + Returns: + (tf.Tensor): Output tensor after applying the attention gate. + """ inp_1_conv = Conv3D( n_intermediate_filters, kernel_size=1, @@ -52,29 +62,55 @@ def attention_gate(inp_1, inp_2, n_intermediate_filters): def attention_concat(conv_below, skip_connection): - '''Performs concatenation of upsampled conv_below with attention gated version of skip-connection - ''' + """ + Concatenates the upsampled `conv_below` with the attention-gated version of `skip_connection`. + + Args: + conv_below (tf.Tensor): The upsampled tensor that will be concatenated. + skip_connection (tf.Tensor): The skip-connection tensor used for attention gating. + + Returns: + (tf.Tensor): Output tensor after concatenation with attention gating. 
+ """ below_filters = conv_below.get_shape().as_list()[-1] attention_across = attention_gate(skip_connection, conv_below, below_filters) return concatenate([conv_below, attention_across]) def conv3d_block( - inputs, - use_batch_norm=True, - dropout=0.3, - dropout_type='spatial', - filters=16, - kernel_size=(3, 3, 3), - activation='relu', - kernel_initializer='he_normal', - padding='valid', + inputs, + use_batch_norm=True, + dropout=0.3, + dropout_type='spatial', + filters=16, + kernel_size=(3, 3, 3), + activation='relu', + kernel_initializer='he_normal', + padding='valid', ): + """ + Create a 3D convolutional block consisting of two convolutional layers with optional batch normalization and + dropout. + + Args: + inputs (tf.Tensor): Input tensor to the convolutional block. + use_batch_norm (bool): Whether to use batch normalization. + dropout (float): Dropout rate for spatial or standard dropout (if enabled). + dropout_type (str): Type of dropout, either 'spatial' or 'standard'. + filters (int): Number of filters (output channels) in the convolutional layers. + kernel_size (tuple): Size of the convolutional kernel in 3D (depth, height, width). + activation (str): Activation function to be used after convolution. + kernel_initializer (str): Initializer for the convolutional kernel weights. + padding (str): Padding mode for the convolutional layers. + + Returns: + (tf.Tensor): Output tensor after passing through the 3D convolutional block. + """ if dropout_type == 'spatial': - DO = SpatialDropout3D + do = SpatialDropout3D elif dropout_type == 'standard': - DO = Dropout + do = Dropout else: raise ValueError( f"dropout_type must be one of ['spatial', 'standard'], got {dropout_type}" @@ -93,7 +129,7 @@ def conv3d_block( else: c = tfa.layers.InstanceNormalization()(c) if dropout > 0.0: - c = DO(dropout)(c) + c = do(dropout)(c) c = ReflectionPadding3D()(c) c = Conv3D( filters, @@ -111,43 +147,41 @@ def conv3d_block( def custom_vnet( - input_shape, - num_classes=1, - activation='relu', - use_batch_norm=True, - upsample_mode='deconv', # 'deconv' or 'simple' - dropout=0.5, - dropout_change_per_layer=0.0, - dropout_type='spatial', - use_dropout_on_upsampling=False, - kernel_initializer='he_normal', - gamma_initializer='he_normal', - use_attention_gate=False, - filters=16, - num_layers=4, - output_activation='sigmoid', - addnoise=False + input_shape, + num_classes=1, + activation='relu', + use_batch_norm=True, + upsample_mode='deconv', # 'deconv' or 'simple' + dropout=0.5, + dropout_change_per_layer=0.0, + dropout_type='spatial', + use_dropout_on_upsampling=False, + kernel_initializer='he_normal', + use_attention_gate=False, + filters=16, + num_layers=4, + output_activation='sigmoid', + addnoise=False ): # 'sigmoid' or 'softmax' - ''' + """ Customizable VNet architecture based on the work of Fausto Milletari, Nassir Navab, Seyed-Ahmad Ahmadi in V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation - Arguments: - input_shape: 4D Tensor of shape (x, y, z, num_channels) - num_classes (int): Unique classes in the output mask. Should be set to 1 for binary segmentation - activation (str): A keras.activations.Activation to use. ReLu by default. - use_batch_norm (bool): Whether to use Batch Normalisation across the channel axis between convolutional layers - upsample_mode (one of 'deconv' or 'simple'): Whether to use transposed convolutions or simple upsampling in the decoder part - dropout (float between 0. and 1.): Amount of dropout after the initial convolutional block. 
Set to 0. to turn Dropout off
-    dropout_change_per_layer (float between 0. and 1.): Factor to add to the Dropout after each convolutional block
-    dropout_type (one of 'spatial' or 'standard'): Type of Dropout to apply. Spatial is recommended for CNNs [2]
-    use_dropout_on_upsampling (bool): Whether to use dropout in the decoder part of the network
-    use_attention_gate (bool): Whether to use an attention dynamic when concatenating with the skip-connection, implemented as proposed by Oktay et al. [3]
-    filters (int): Convolutional filters in the initial convolutional block. Will be doubled every block
-    num_layers (int): Number of total layers in the encoder not including the bottleneck layer
-    output_activation (str): A keras.activations.Activation to use. Sigmoid by default for binary segmentation
+    Arguments:
+        input_shape: 4D Tensor of shape (x, y, z, num_channels)
+        num_classes (int): Unique classes in the output mask. Should be set to 1 for binary segmentation
+        activation (str): A keras.activations.Activation to use. ReLU by default.
+        use_batch_norm (bool): Whether to use Batch Normalisation across the channel axis between convolutional layers
+        upsample_mode (one of 'deconv' or 'simple'): Whether to use transposed convolutions or simple upsampling in the decoder part
+        dropout (float between 0. and 1.): Amount of dropout after the initial convolutional block. Set to 0. to turn Dropout off
+        dropout_change_per_layer (float between 0. and 1.): Factor to add to the Dropout after each convolutional block
+        dropout_type (one of 'spatial' or 'standard'): Type of Dropout to apply. Spatial is recommended for CNNs [2]
+        use_dropout_on_upsampling (bool): Whether to use dropout in the decoder part of the network
+        use_attention_gate (bool): Whether to use an attention gate when concatenating with the skip-connection, implemented as proposed by Oktay et al. [3]
+        filters (int): Convolutional filters in the initial convolutional block. Will be doubled every block
+        num_layers (int): Number of total layers in the encoder not including the bottleneck layer
+        output_activation (str): A keras.activations.Activation to use. Sigmoid by default for binary segmentation

    Returns:
    model (keras.models.Model): The built V-Net

@@ -160,22 +194,22 @@ def custom_vnet(
    [2]: https://arxiv.org/pdf/1411.4280.pdf
    [3]: https://arxiv.org/abs/1804.03999

-    '''
+    """

    # Build model
    inputs = Input(input_shape)
    x = inputs
-
+
    if addnoise:
        x = min_max_norm_tf(x) + tf.random.normal(shape=tf.shape(x),
-                                              mean=-0.475,
-                                              stddev=0.06)
+                                                  mean=-0.475,
+                                                  stddev=0.06)
        x = tf.math.add(x, inputs)
        x = tf.clip_by_value(x, 0., 1.)
        x = rescale_arr_tf(x, -0.5, 0.5)

    down_layers = []
-    for l in range(num_layers):
+    for layer in range(num_layers):
        x = conv3d_block(
            inputs=x,
            filters=filters,
@@ -184,7 +218,7 @@ def custom_vnet(
            dropout_type=dropout_type,
            kernel_initializer=kernel_initializer,
            activation=activation,
-            )
+        )
        down_layers.append(x)
        x = MaxPooling3D((2, 2, 2))(x)
        dropout += dropout_change_per_layer
@@ -231,4 +265,4 @@ def custom_vnet(
    model = Model(inputs=[inputs], outputs=[outputs])
    model.summary()

-    return model
\ No newline at end of file
+    return model
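Finally, a minimal usage sketch for this V-Net builder. The 64-cubed single-channel input shape is illustrative, not a size prescribed by the repository; because the 'valid' convolutions are backed by reflection padding, the spatial dimensions are preserved and the prediction should match the input shape:

    import tensorflow as tf
    from vnet_model import custom_vnet

    model = custom_vnet(
        input_shape=(64, 64, 64, 1),  # (x, y, z, num_channels)
        num_classes=1,                # binary segmentation
        use_attention_gate=True,      # attention-gated skips (Oktay et al.)
        filters=16,
        num_layers=4,
        output_activation='sigmoid',
    )
    pred = model(tf.random.normal((1, 64, 64, 64, 1)), training=False)
    # pred.shape == (1, 64, 64, 64, 1)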