Skip to content

Commit

Permalink
Minor documentation updates and code optimisations. Code now outputs …
Browse files Browse the repository at this point in the history
…'args' settings in main. Also, include a function to perform clahe to 2D slices of a 3D image. (#4)
  • Loading branch information
psweens authored Oct 2, 2023
1 parent bdb2fff commit 5fbd02d
Show file tree
Hide file tree
Showing 8 changed files with 354 additions and 175 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Tensorflow and the remaining Python packages below can be installed in a [_conda

The remaining required Python packages can be installed using _pip_ in a terminal window:
```bash
pip install opencv-python scikit-image tqdm tensorflow_addons tensorflow-mri joblib matplotlib
pip install opencv-python scikit-image tqdm tensorflow_addons joblib matplotlib
```

VAN-GAN has been tested on Ubuntu 22.04.2 LTS with Python 3.9.16 and the following package versions:
Expand All @@ -59,10 +59,10 @@ VAN-GAN code was originally developed by [Paul W. Sweeney](https://www.psweeney.

Please get in contact in you have any questions.

## References
## Citation
If you use this code or data, we kindly ask that you please cite the below:
> [Segmentation of 3D blood vessel networks using unsupervised deep learning](https://doi.org/10.1101/2023.04.30.538453)<br>
> Paul W. Sweeney et al.
> Paul W. Sweeney et al. *bioRxiv*
## Licence
The project is licenced under the MIT Licence.
Expand Down
29 changes: 12 additions & 17 deletions custom_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,11 @@ def __init__(self,
self.imaging_val_data = imaging_val_data
self.segmentation_val_data = segmentation_val_data
self.process_imaging_domain = process_imaging_domain
self.period = args.PERIOD_2D_CALLBACK,
self.period3D = args.PERIOD_3D_CALLBACK,
self.model_path = args.output_dir,
self.period = args.PERIOD_2D_CALLBACK
self.period3D = args.PERIOD_3D_CALLBACK
self.model_path = args.output_dir
self.dims = args.DIMENSIONS

self.period = self.period[0]
self.period3D = self.period3D[0]
self.model_path = self.model_path[0]

def save_model(self, model, epoch):
"""Save the trained model at the given epoch.
Expand All @@ -43,10 +39,10 @@ def save_model(self, model, epoch):
"""

# if epoch > 100:
model.gen_AB.save(os.path.join(self.model_path, "checkpoints/e{epoch}_genAB".format(epoch=epoch + 1)))
model.gen_BA.save(os.path.join(self.model_path, "checkpoints/e{epoch}_genBA".format(epoch=epoch + 1)))
model.disc_A.save(os.path.join(self.model_path, "checkpoints/e{epoch}_discA".format(epoch=epoch + 1)))
model.disc_B.save(os.path.join(self.model_path, "checkpoints/e{epoch}_discB".format(epoch=epoch + 1)))
model.gen_IS.save(os.path.join(self.model_path, "checkpoints/e{epoch}_genAB".format(epoch=epoch + 1)))
model.gen_SI.save(os.path.join(self.model_path, "checkpoints/e{epoch}_genBA".format(epoch=epoch + 1)))
model.disc_I.save(os.path.join(self.model_path, "checkpoints/e{epoch}_discA".format(epoch=epoch + 1)))
model.disc_S.save(os.path.join(self.model_path, "checkpoints/e{epoch}_discB".format(epoch=epoch + 1)))

def stitch_subvolumes(self, gen, img, subvol_size,
epoch=-1, stride=(25, 25, 128),
Expand Down Expand Up @@ -173,7 +169,7 @@ def stitch_subvolumes(self, gen, img, subvol_size,
start_dep:(start_dep + kD)]

if process_img and self.process_imaging_domain is not None:
arr = self.process_imaging_domain(arr)
arr = self.process_imaging_domain(arr, axis=None, keepdims=False)

arr = gen(np.expand_dims(arr,
axis=0), training=False)[0]
Expand Down Expand Up @@ -419,14 +415,13 @@ def updateDiscriminatorNoise(self, model, init_noise, epoch, args):
else:
decay_rate = epoch / args.NO_NOISE
noise = init_noise * (1. - decay_rate)
if noise < 0.0:
noise = 0.0
# noise = 0.9 ** (epoch + 1)
print('Noise std: %0.5f' % noise)
for layer in model.layers:
if type(layer) == layers.GaussianNoise:
if noise > 0.:
layer.stddev = noise
else:
layer.stddev = 0.0
if isinstance(layer, tf.keras.layers.GaussianNoise):
layer.stddev = noise

def on_epoch_start(self, model, epoch, args, logs=None):
"""
Expand Down
Loading

0 comments on commit 5fbd02d

Please sign in to comment.