From a45d037bdb125c1211f2a5153d4080a84ae4c489 Mon Sep 17 00:00:00 2001 From: LorenzLamm Date: Sun, 31 Dec 2023 14:05:37 +0100 Subject: [PATCH 01/23] Surface-Dice functionalities --- .../segmentation/training/surface_dice.py | 318 ++++++++++++++++++ 1 file changed, 318 insertions(+) create mode 100644 src/membrain_seg/segmentation/training/surface_dice.py diff --git a/src/membrain_seg/segmentation/training/surface_dice.py b/src/membrain_seg/segmentation/training/surface_dice.py new file mode 100644 index 0000000..f0c6868 --- /dev/null +++ b/src/membrain_seg/segmentation/training/surface_dice.py @@ -0,0 +1,318 @@ +""" +Adapted from: clDice - A Novel Topology-Preserving Loss Function for Tubular Structure Segmentation +Original Authors: Johannes C. Paetzold and Suprosanna Shit +Sources: https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py +License: MIT License + +The following code is a modification of the original clDice implementation. Modifications were made to +include additional functionality and integrate with new project requirements. The original license and +copyright notice are provided below. + +MIT License + +Copyright (c) 2021 Johannes C. Paetzold and Suprosanna Shit + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import torch +import torch.nn.functional as F +from .soft_skeleton import soft_skel +from torch.nn.functional import sigmoid +from torch.nn.modules.loss import _Loss +from scipy.ndimage import gaussian_filter + + + +def soft_erode(img: torch.Tensor, separate_pool: bool = False) -> torch.Tensor: + """ + Apply soft erosion operation to the input image. + + Soft erosion is achieved by applying a min-pooling operation to the input image. + + Parameters + ---------- + img : torch.Tensor + Input image tensor with shape (B, C, D, H, W) + separate_pool : bool, optional + If True, perform separate 3D max-pooling operations along different axes. + Default is False. + + Returns + ------- + torch.Tensor + Eroded image tensor with the same shape as the input. + + Raises + ------ + ValueError + If the input tensor has an unsupported number of dimensions. + + Notes + ----- + - The soft erosion can be performed with separate 3D min-pooling operations + along different axes if separate_pool is True, or with a single + 3D min-pooling operation with a kernel of size (3, 3, 3) if + separate_pool is False. 
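+    - PyTorch has no native min-pooling, so the erosion is implemented here
+      as -max_pool3d(-img), which is equivalent to a min-pooling.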
+ """ + assert len(img.shape)==5 + if separate_pool: + p1 = -F.max_pool3d(-img,(3,1,1),(1,1,1),(1,0,0)) + p2 = -F.max_pool3d(-img,(1,3,1),(1,1,1),(0,1,0)) + p3 = -F.max_pool3d(-img,(1,1,3),(1,1,1),(0,0,1)) + return torch.min(torch.min(p1, p2), p3) + p4 = -F.max_pool3d(-img,(3,3,3),(1,1,1),(1,1,1)) + return p4 + + +def soft_dilate(img: torch.Tensor) -> torch.Tensor: + """ + Apply soft dilation operation to the input image. + + Soft dilation is achieved by applying a max-pooling operation to the input image. + + Parameters + ---------- + img : torch.Tensor + Input image tensor with shape (B, C, D, H, W). + + Returns + ------- + torch.Tensor + Dilated image tensor with the same shape as the input. + + Raises + ------ + ValueError + If the input tensor has an unsupported number of dimensions. + + Notes + ----- + - For 5D input, the soft dilation is performed using a 3D max-pooling operation + with a kernel of size (3, 3, 3). + """ + assert len(img.shape)==5 + return F.max_pool3d(img,(3,3,3),(1,1,1),(1,1,1)) + + +def soft_open(img: torch.Tensor, separate_pool: bool = False) -> torch.Tensor: + """ + Apply soft opening operation to the input image. + + Soft opening is achieved by applying soft erosion followed by soft dilation. + The intention of soft opening is to remove thin membranes from the segmentation. + + Parameters + ---------- + img : torch.Tensor + Input image tensor with shape (B, C, D, H, W). + separate_pool : bool, optional + If True, perform separate erosion and dilation operations. Default is False. + + Returns + ------- + torch.Tensor + Opened image tensor with the same shape as the input. + + Notes + ----- + - Soft opening is performed by applying soft erosion followed by soft dilation + to the input image. + - For 5D input, separate erosion and dilation can be performed if separate_pool + is True. + """ + return soft_dilate(soft_erode(img, separate_pool=separate_pool)) + + +def soft_skel(img: torch.Tensor, iter_: int, separate_pool: bool = False) -> torch.Tensor: + """ + Compute the soft skeleton of the input image. + + The skeleton is computed by applying soft erosion iteratively to the input image. + In each iteration, the difference between the input image and the "opened" image is + computed and added to the skeleton. + + Reasoning: if there is a difference between the input image and the "opened" image, + there must be a thin membrane skeleton in the input image that was removed by the + opening operation. + + Parameters + ---------- + img : torch.Tensor + Input image tensor with shape (B, C, D, H, W). + iter_ : int + Number of iterations for skeletonization. + separate_pool : bool, optional + If True, perform separate erosion and dilation operations. Default is False. + + Returns + ------- + torch.Tensor + Soft skeleton image tensor with the same shape as the input. + + Notes + ----- + - Separate erosion can be performed if separate_pool is True. + """ + img1 = soft_open(img, separate_pool=separate_pool) + skel = F.relu(img-img1) + for j in range(iter_): + img = soft_erode(img) + img1 = soft_open(img, separate_pool=separate_pool) + delta = F.relu(img-img1) + skel = skel + F.relu(delta-skel*delta) + return skel + + +def get_GT_skeleton(gt_seg: torch.Tensor, iterations: int = 5) -> torch.Tensor: + """ + Generate the skeleton of a ground truth segmentation. + + This function takes a ground truth segmentation `gt_seg`, smooths it using a Gaussian filter, + and then computes its soft skeleton using the `soft_skel` function. 
+ + Intention: When using the binary ground truth segmentation for skeletonization, the resulting + skeleton is very patchy and not smooth. When using the smoothed ground truth segmentation, the + resulting skeleton is much smoother and more accurate. + + Parameters + ---------- + gt_seg : torch.Tensor + A torch.Tensor representing the ground truth segmentation. Shape: (B, C, D, H, W) + iterations : int, optional + The number of iterations for skeletonization. Default is 5. + + Returns + ------- + torch.Tensor + A torch.Tensor representing the skeleton of the ground truth segmentation. + + Notes + ----- + - The input `gt_seg` should be a binary segmentation tensor where 1 represents the object of interest. + - The function first smooths the `gt_seg` using a Gaussian filter to enhance the object's structure. + - The skeletonization process is performed using the `soft_skel` function with the specified number of + iterations. + - The resulting skeleton is returned as a binary torch.Tensor where 1 indicates the skeleton points. + """ + gt_smooth = gaussian_filter((gt_seg == 1) * 1., 2) * 1.5 # The Gaussian smoothing parameters are work in progress + skel_gt = soft_skel(torch.from_numpy(gt_smooth)*1.0, iter_=iterations) + return skel_gt + + +def masked_surface_dice(data: torch.Tensor, target: torch.Tensor, ignore_label: int=2, soft_skel_iterations: int=3, + smooth: float=3., binary_prediction: bool=False) -> torch.Tensor: + """ + Compute the surface Dice loss with masking for ignore labels. + + The surface Dice loss measures the similarity between the predicted segmentation's skeleton and + the ground truth segmentation (and vice versa). Labels annotated with "ignore_label" are ignored. + + Parameters + ---------- + data : torch.Tensor + Tensor of model outputs representing the predicted segmentation. + target : torch.Tensor + Tensor of target labels representing the ground truth segmentation. + ignore_label : int + The label value to be ignored when computing the loss. + soft_skel_iterations : int + Number of iterations for skeletonization in the underlying operations. + smooth : float + Smoothing factor to avoid division by zero. + + Returns + ------- + torch.Tensor + The calculated surface Dice loss. + """ + # Create a mask to ignore the specified label in the target + data = sigmoid(data) + mask = target != ignore_label + + # Compute soft skeletonization + if binary_prediction: + skel_pred = get_GT_skeleton(data.clone(), soft_skel_iterations, separate_pool=False) + else: + skel_pred = soft_skel(data.clone(), soft_skel_iterations, separate_pool=False) + skel_true = get_GT_skeleton(target.clone(), soft_skel_iterations, separate_pool=False) + + # Mask out ignore labels + skel_pred[~mask] = 0 + skel_true[~mask] = 0 + + # compute surface dice loss + print("ALSO CHECK DIMENSIONS AND VALUES HERE!!! (Surface dice loss)") + print("ALSO CHECK DIMENSIONS AND VALUES HERE!!! (Surface dice loss)") + print("ALSO CHECK DIMENSIONS AND VALUES HERE!!! (Surface dice loss)") + tprec = (torch.sum(torch.multiply(skel_pred, target), dim=0)+smooth)/(torch.sum(skel_pred, dim=0)+smooth) + tsens = (torch.sum(torch.multiply(skel_true, data), dim=0)+smooth)/(torch.sum(skel_true, dim=0)+smooth) + surf_dice_loss = 2.0*(tprec*tsens)/(tprec+tsens) + return surf_dice_loss + + +class IgnoreLabelSurfaceDiceLoss(_Loss): + """ + Surface Dice loss, adding ignore labels. + + Parameters + ---------- + ignore_label : int + The label to ignore when calculating the loss. 
soft_skel_iterations : int, optional
+        Number of iterations for the soft-skeleton computation, by default 3.
+    smooth : float, optional
+        Smoothing factor to avoid division by zero, by default 3.0.
+    kwargs : dict
+        Additional keyword arguments.
+    """
+
+    def __init__(
+        self,
+        ignore_label: int,
+        soft_skel_iterations: int = 3,
+        smooth: float = 3.,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        self.ignore_label = ignore_label
+        self.soft_skel_iterations = soft_skel_iterations
+        self.smooth = smooth
+
+    def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Compute the loss.
+
+        Parameters
+        ----------
+        data : torch.Tensor
+            Tensor of model outputs.
+        target : torch.Tensor
+            Tensor of target labels.
+
+        Returns
+        -------
+        torch.Tensor
+            The calculated loss.
+        """
+        # Create a mask to ignore the specified label in the target
+        surf_dice_score = masked_surface_dice(data=data, target=target, ignore_label=self.ignore_label, soft_skel_iterations=self.soft_skel_iterations, smooth=self.smooth)
+        surf_dice_loss = 1. - surf_dice_score
+
+
+        return surf_dice_loss
\ No newline at end of file

From df4a3765ddf2277eb32159deeedf816941996022 Mon Sep 17 00:00:00 2001
From: LorenzLamm
Date: Sun, 31 Dec 2023 14:06:49 +0100
Subject: [PATCH 02/23] Adjust losses to be compatible with Surface-Dice
 exclusions

---
 .../segmentation/training/optim_utils.py     | 104 +++++++++++++++++-
 1 file changed, 101 insertions(+), 3 deletions(-)

diff --git a/src/membrain_seg/segmentation/training/optim_utils.py b/src/membrain_seg/segmentation/training/optim_utils.py
index 8dbdc6b..c070ceb 100644
--- a/src/membrain_seg/segmentation/training/optim_utils.py
+++ b/src/membrain_seg/segmentation/training/optim_utils.py
@@ -53,7 +53,7 @@ class IgnoreLabelDiceCELoss(_Loss):
     def __init__(
         self,
         ignore_label: int,
-        reduction: str = "mean",
+        reduction: str = "none",
         lambda_dice: float = 1.0,
         lambda_ce: float = 1.0,
         **kwargs,
@@ -95,11 +95,23 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             orig_data, target_tensor, reduction="none"
         )
         bce_loss[~mask] = 0.0
-        bce_loss = torch.sum(bce_loss) / torch.sum(mask)
+        #TODO: Check if this is correct: I adjusted the loss to be computed per batch element
+        bce_loss = torch.sum(bce_loss, dim=0) / torch.sum(mask, dim=0)
+        print(bce_loss.shape, "<------------- Please check the shape of the BCE loss here!!")
+        print(bce_loss.shape, "<------------- Please check the shape of the BCE loss here!!")
+        print(bce_loss.shape, "<------------- Please check the shape of the BCE loss here!!")
+        print(bce_loss.shape, "<------------- Please check the shape of the BCE loss here!!")
 
         dice_loss = self.dice_loss(data, target, mask)
+        dice_loss = torch.sum(dice_loss, dim=0) / torch.sum(mask, dim=0)
+
+        print(bce_loss.shape, "<------------- Please check the shape of the Dice loss here!!")
+        print(bce_loss.shape, "<------------- Please check the shape of the Dice loss here!!")
+        print(bce_loss.shape, "<------------- Please check the shape of the Dice loss here!!")
+        print(bce_loss.shape, "<------------- Please check the shape of the Dice loss here!!")
+        print("Also check values!")
 
         # Combine the Dice and Cross Entropy losses
-        combined_loss = self.lambda_dice * dice_loss + self.lambda_ce * bce_loss
+        combined_loss = self.lambda_dice * dice_loss + self.lambda_ce * bce_loss # Combined loss should be per batch element
 
         return combined_loss
 
@@ -154,3 +166,89 @@ def forward(self, inputs: list, targets: list) -> torch.Tensor:
         for weight, data, target in zip(self.weights, inputs, targets):
             loss += weight * self.loss_fn(data, target)
         return loss
+
+
+class CombinedLoss(_Loss):
+    """
+    Combine multiple loss functions into a single one.
+
+    Parameters
+    ----------
+    losses : List[Callable]
+        A list of loss function instances.
+    weights : List[float]
+        List of weights corresponding to each loss function (must 
+        be of same length as losses).
+    loss_inclusion_tokens : List[List[str]]
+        A list of lists containing tokens for each loss function. 
+        Each sublist corresponds to a loss function and contains 
+        tokens for which the loss should be included.
+        If the list contains "all", then the loss will be included
+        for all cases.
+
+    Notes
+    ----------
+    IMPORTANT: Loss functions need to return a tensor containing the 
+    loss for each batch element. 
+
+    The loss_inclusion_tokens parameter restricts the loss calculation
+    to certain cases. For example, if the loss_inclusion_tokens
+    parameter is [["ds1", "ds2"], ["ds1"]], then the first loss function
+    is only applied to cases where the dataset label is "ds1" or "ds2",
+    and the second loss function is only applied to cases where the
+    dataset label is "ds1".
+
+    
+    """
+
+    def __init__(
+        self,
+        losses: list,
+        weights: list,
+        loss_inclusion_tokens: list,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        self.losses = losses
+        self.weights = weights
+        self.loss_inclusion_tokens = loss_inclusion_tokens
+
+    def forward(self, data: torch.Tensor, target: torch.Tensor, ds_label: list) -> torch.Tensor:
+        """
+        Compute the combined loss.
+
+        Parameters
+        ----------
+        data : torch.Tensor
+            Tensor of model outputs.
+        target : torch.Tensor
+            Tensor of target labels.
+        ds_label : List[str]
+            List of dataset labels for each batch element.
+
+        Returns
+        -------
+        torch.Tensor
+            The calculated combined loss.
+        """
+
+        loss = 0.
+        for loss_idx, (cur_loss, cur_weight, skip_cases) in enumerate(zip(self.losses, self.weights, self.loss_inclusion_tokens)):
+            cur_loss_val = cur_loss(data, target)
+
+            print("CHECK WHETHER THE COMBINED LOSS IS CORRECTLY COMPUTED HERE!")
+            print("CHECK WHETHER THE COMBINED LOSS IS CORRECTLY COMPUTED HERE!")
+            # Zero out losses for excluded cases
+            for batch_idx, ds_lab in enumerate(ds_label):
+                if "all" in self.loss_inclusion_tokens[loss_idx] or \
+                    ds_lab in self.loss_inclusion_tokens[loss_idx]:
+                    continue
+                cur_loss_val[batch_idx] = 0.
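+
+            # NOTE: excluded batch elements keep a loss of exactly 0; the
+            # aggregation below averages over non-zero entries only, and its
+            # 1e-3 term avoids division by zero if all elements are excluded.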
+ + # Aggregate loss + cur_loss_val = cur_loss_val.sum() / ((cur_loss_val != 0.).sum() + 1e-3) + loss += cur_weight * cur_loss_val + + # Normalize loss + loss = loss / self.weights.sum() + return loss From 79859e033b2b5d54f45c93351f7029ed3ba1b88a Mon Sep 17 00:00:00 2001 From: LorenzLamm Date: Sun, 31 Dec 2023 14:07:48 +0100 Subject: [PATCH 03/23] Adjust training routine and include surface dice loss --- .../segmentation/networks/unet.py | 94 ++++++++++++++----- 1 file changed, 71 insertions(+), 23 deletions(-) diff --git a/src/membrain_seg/segmentation/networks/unet.py b/src/membrain_seg/segmentation/networks/unet.py index 7a67310..1f9510a 100644 --- a/src/membrain_seg/segmentation/networks/unet.py +++ b/src/membrain_seg/segmentation/networks/unet.py @@ -9,17 +9,15 @@ from ..training.metric_utils import masked_accuracy, threshold_function -# from monai.networks.nets import UNet as MonaiUnet -# The normal Monai DynUNet upsamples low-resolution layers to compare directly to GT -# My implementation leaves them in low resolution and compares to down-sampled GT -# Not sure which implementation is better -# To be discussed with Alister & Kevin from ..training.optim_utils import ( DeepSuperVisionLoss, - DynUNetDirectDeepSupervision, # I like to use deep supervision + DynUNetDirectDeepSupervision, IgnoreLabelDiceCELoss, + CombinedLoss, ) +from ..training.surface_dice import IgnoreLabelSurfaceDiceLoss, masked_surface_dice + class SemanticSegmentationUnet(pl.LightningModule): """Implementation of a Unet for semantic segmentation. @@ -62,6 +60,12 @@ class SemanticSegmentationUnet(pl.LightningModule): The maximum number of epochs for training. use_deep_supervision : bool, default=False Whether to use deep supervision. + use_surf_dice : bool, default=False + Whether to use surface dice loss. + surf_dice_weight : float, default=1.0 + The weight for the surface dice loss. + surf_dice_tokens : list, default=[] + The tokens for which to compute the surface dice loss. """ @@ -80,6 +84,9 @@ def __init__( roi_size: Tuple[int, ...] = (160, 160, 160), max_epochs: int = 1000, use_deep_supervision: bool = False, + use_surf_dice=False, + surf_dice_weight=1.0, + surf_dice_tokens=[], ): super().__init__() @@ -102,21 +109,39 @@ def __init__( upsample_kernel_size=(1, 2, 2, 2, 2, 2), filters=channels, res_block=True, - # norm_name="INSTANCE", - # norm=Norm.INSTANCE, # I like the instance normalization better than - # batchnorm in this case, as we will probably have - # only small batch sizes, making BN more noisy deep_supervision=True, deep_supr_num=2, ) + + + ### Build up loss function + + losses = [] + weights = [] + loss_inclusion_tokens = [] ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="mean") + losses.append(ignore_dice_loss) + weights.append(1.) 
+ loss_inclusion_tokens.append(["all"]) # Apply to every element + + if use_surf_dice: + ignore_surf_dice_loss = IgnoreLabelSurfaceDiceLoss(ignore_label=2, soft_skel_iterations=3) + losses.append(ignore_surf_dice_loss) + weights.append(surf_dice_weight) + loss_inclusion_tokens.append(surf_dice_tokens) + + scaled_weights = [entry / sum(weights) for entry in weights] + loss_function = CombinedLoss(losses=losses, weights=scaled_weights, + loss_exclusion_tokens=loss_inclusion_tokens) + self.loss_function = DeepSuperVisionLoss( - ignore_dice_loss, + loss_function, weights=[1.0, 0.5, 0.25, 0.125, 0.0675] if use_deep_supervision else [1.0, 0.0, 0.0, 0.0, 0.0], ) + # validation metric self.dice_metric = DiceMetric( include_background=False, reduction="mean", get_not_nans=False @@ -143,7 +168,9 @@ def __init__( self.training_step_outputs = [] self.validation_step_outputs = [] self.running_train_acc = 0.0 + self.running_train_surf_dice = 0.0 self.running_val_acc = 0.0 + self.running_val_surf_dice = 0.0 def forward(self, x) -> torch.Tensor: """Implementation of the forward pass. @@ -180,9 +207,9 @@ def training_step( See the pytorch-lightning module documentation for details. """ - images, labels = batch["image"], batch["label"] + images, labels, ds_label = batch["image"], batch["label"], batch['dataset'] output = self.forward(images) - loss = self.loss_function(output, labels) + loss = self.loss_function(output, labels, ds_label) stats_dict = {"train_loss": loss, "train_number": output[0].shape[0]} self.training_step_outputs.append(stats_dict) @@ -191,6 +218,14 @@ def training_step( * output[0].shape[0] ) + self.running_train_surf_dice += ( + masked_surface_dice(data=output[0].detach(), + target=labels[0].detach(), + ignore_label=2., + soft_skel_iterations=3, + smooth=1.) * output[0].shape[0] + ) + return {"loss": loss} def on_train_epoch_end(self): @@ -207,13 +242,17 @@ def on_train_epoch_end(self): mean_train_loss = torch.tensor(train_loss / num_items) mean_train_acc = self.running_train_acc / num_items + mean_train_surf_dice = self.running_train_surf_dice / num_items self.running_train_acc = 0.0 - self.log("train_loss", mean_train_loss) # , batch_size=num_items) - self.log("train_acc", mean_train_acc) # , batch_size=num_items) + self.running_train_surf_dice = 0.0 + self.log("train_loss", mean_train_loss) + self.log("train_acc", mean_train_acc) + self.log("train_surf_dice", mean_train_surf_dice) self.training_step_outputs = [] print("EPOCH Training loss", mean_train_loss.item()) print("EPOCH Training acc", mean_train_acc.item()) + print("EPOCH Training surface dice", mean_train_surf_dice.item()) # Accuracy not the most informative metric, but a good sanity check return {"train_loss": mean_train_loss} @@ -224,13 +263,9 @@ def validation_step(self, batch, batch_idx): using a sliding window. See the pytorch-lightning module documentation for details. """ - images, labels = batch[self.image_key], batch[self.label_key] - # sw_batch_size = 4 - # outputs = sliding_window_inference( - # images, self.roi_size, sw_batch_size, self.forward - # ) + images, labels, ds_label = batch["image"], batch["label"], batch["dataset"] outputs = self.forward(images) - loss = self.loss_function(outputs, labels) + loss = self.loss_function(outputs, labels, ds_label) # Cloning and adjusting preds & labels for Dice. 
# Could also use the same labels, but maybe we want to @@ -254,6 +289,15 @@ def validation_step(self, batch, batch_idx): ) * outputs[0].shape[0] ) + + self.running_val_surf_dice += ( + masked_surface_dice(data=outputs[0].detach(), + target=labels[0].detach(), + ignore_label=2., + soft_skel_iterations=3, + smooth=1.) * outputs[0].shape[0] + ) + return stats_dict def on_validation_epoch_end(self): @@ -270,13 +314,17 @@ def on_validation_epoch_end(self): mean_val_loss = torch.tensor(val_loss / num_items) mean_val_acc = self.running_val_acc / num_items + mean_val_surf_dice = self.running_val_surf_dice / num_items self.running_val_acc = 0.0 - self.log("val_loss", mean_val_loss), # batch_size=num_items) - self.log("val_dice", mean_val_dice) # , batch_size=num_items) + self.running_val_surf_dice = 0.0 + self.log("val_loss", mean_val_loss), + self.log("val_dice", mean_val_dice) + self.log("val_surf_dice", mean_val_surf_dice) self.log("val_accuracy", mean_val_acc) self.validation_step_outputs = [] print("EPOCH Validation loss", mean_val_loss.item()) print("EPOCH Validation dice", mean_val_dice) + print("EPOCH Validation surface dice", mean_val_surf_dice.item()) print("EPOCH Validation acc", mean_val_acc.item()) return {"val_loss": mean_val_loss, "val_metric": mean_val_dice} From 4868e983f79d401e82a4cd2ca06f06049f145bbd Mon Sep 17 00:00:00 2001 From: LorenzLamm Date: Sun, 31 Dec 2023 14:08:49 +0100 Subject: [PATCH 04/23] Add dataset labels to dataloading --- .../segmentation/dataloading/memseg_dataset.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/membrain_seg/segmentation/dataloading/memseg_dataset.py b/src/membrain_seg/segmentation/dataloading/memseg_dataset.py index d856883..727b23a 100644 --- a/src/membrain_seg/segmentation/dataloading/memseg_dataset.py +++ b/src/membrain_seg/segmentation/dataloading/memseg_dataset.py @@ -102,6 +102,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: "label": np.expand_dims(self.labels[idx], 0), } idx_dict = self.transforms(idx_dict) + idx_dict["dataset"] = self.dataset_labels[idx] return idx_dict def __len__(self) -> int: @@ -126,6 +127,7 @@ def load_data(self) -> None: print("Loading images into dataset.") self.imgs = [] self.labels = [] + self.dataset_labels = [] for entry in self.data_paths: label = read_nifti( entry[1] @@ -137,6 +139,8 @@ def load_data(self) -> None: img = np.transpose(img, (1, 2, 0)) self.imgs.append(img) self.labels.append(label) + self.dataset_labels.append(get_dataset_token(entry[0])) + def initialize_imgs_paths(self) -> None: """ @@ -190,3 +194,9 @@ def test(self, test_folder: str, num_files: int = 20) -> None: os.path.join(test_folder, f"test_mask_ds2_{i}_group{num_mask}.png"), test_sample["label"][1][0, :, :, num_mask], ) + + +def get_dataset_token(patch_name): + basename = os.path.basename(patch_name) + dataset_token = basename.split("_")[0] + return dataset_token \ No newline at end of file From 7bf110b31ad0e8a7604c4f17103e3dd7f6e703c1 Mon Sep 17 00:00:00 2001 From: LorenzLamm Date: Sun, 31 Dec 2023 14:09:26 +0100 Subject: [PATCH 05/23] Pass Surface-Dice arguments to training routine --- src/membrain_seg/segmentation/train.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/membrain_seg/segmentation/train.py b/src/membrain_seg/segmentation/train.py index 2f077f2..7c059d7 100644 --- a/src/membrain_seg/segmentation/train.py +++ b/src/membrain_seg/segmentation/train.py @@ -24,6 +24,9 @@ def train( use_deep_supervision: bool = False, project_name: str = "membrain-seg_v0", 
sub_name: str = "1", + use_surf_dice: bool = False, + surf_dice_weight: float = 1.0, + surf_dice_tokens: list = ["all"], ): """ Train the model on the specified data. @@ -67,7 +70,11 @@ def train( # Set up the model model = SemanticSegmentationUnet( - max_epochs=max_epochs, use_deep_supervision=use_deep_supervision + max_epochs=max_epochs, + use_deep_supervision=use_deep_supervision, + use_surf_dice=use_surf_dice, + surf_dice_weight=surf_dice_weight, + surf_dice_tokens=surf_dice_tokens, ) project_name = project_name From 860aa1cacf2d3fcd8a8e985fe4a76558601e5e8a Mon Sep 17 00:00:00 2001 From: LorenzLamm Date: Sun, 31 Dec 2023 14:15:00 +0100 Subject: [PATCH 06/23] Update CLI to include advanced options for Surface-Dice --- .../segmentation/cli/train_cli.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/membrain_seg/segmentation/cli/train_cli.py b/src/membrain_seg/segmentation/cli/train_cli.py index ac3d394..c15c759 100644 --- a/src/membrain_seg/segmentation/cli/train_cli.py +++ b/src/membrain_seg/segmentation/cli/train_cli.py @@ -84,6 +84,18 @@ def train_advanced( but also severely increases training time.\ Pass "True" or "False".', ), + use_surface_dice: bool = Option( # noqa: B008 + True, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".' + ), + surface_dice_weight: float = Option( # noqa: B008 + 1.0, help='Scaling factor for the Surface-Dice loss. ' + ), + surface_dice_tokens: list = Option( # noqa: B008 + ["all"], + help='List of tokens to use for the Surface-Dice loss. \ + Pass a list of strings.\ + For example, ["all", "membrane"]', + ), use_deep_supervision: bool = Option( # noqa: B008 True, help='Whether to use deep supervision. Pass "True" or "False".' ), @@ -119,6 +131,12 @@ def train_advanced( If set to False, data augmentation still happens, but not as frequently. More data augmentation can lead to a better performance, but also increases the training time substantially. + use_surface_dice : bool + Determines whether to use Surface-Dice loss, by default True. + surface_dice_weight : float + Scaling factor for the Surface-Dice loss, by default 1.0. + surface_dice_tokens : list + List of tokens to use for the Surface-Dice loss, by default ["all"]. use_deep_supervision : bool Determines whether to use deep supervision, by default True. project_name : str @@ -140,6 +158,9 @@ def train_advanced( max_epochs=max_epochs, aug_prob_to_one=aug_prob_to_one, use_deep_supervision=use_deep_supervision, + use_surf_dice=use_surface_dice, + surf_dice_weight=surface_dice_weight, + surf_dice_tokens=surface_dice_tokens, project_name=project_name, sub_name=sub_name, ) From bd543f54ba1ebab9d46a03bad03c971eb21e497c Mon Sep 17 00:00:00 2001 From: LorenzLamm Date: Sun, 31 Dec 2023 17:54:04 +0100 Subject: [PATCH 07/23] precommit formatting --- .../segmentation/cli/train_cli.py | 2 +- .../dataloading/memseg_dataset.py | 17 +- .../segmentation/networks/unet.py | 71 +++--- src/membrain_seg/segmentation/train.py | 16 +- .../segmentation/training/optim_utils.py | 83 +++++-- .../segmentation/training/surface_dice.py | 227 ++++++++++-------- 6 files changed, 255 insertions(+), 161 deletions(-) diff --git a/src/membrain_seg/segmentation/cli/train_cli.py b/src/membrain_seg/segmentation/cli/train_cli.py index c15c759..47b3770 100644 --- a/src/membrain_seg/segmentation/cli/train_cli.py +++ b/src/membrain_seg/segmentation/cli/train_cli.py @@ -88,7 +88,7 @@ def train_advanced( True, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".' 
), surface_dice_weight: float = Option( # noqa: B008 - 1.0, help='Scaling factor for the Surface-Dice loss. ' + 1.0, help="Scaling factor for the Surface-Dice loss. " ), surface_dice_tokens: list = Option( # noqa: B008 ["all"], diff --git a/src/membrain_seg/segmentation/dataloading/memseg_dataset.py b/src/membrain_seg/segmentation/dataloading/memseg_dataset.py index 727b23a..9af7c0f 100644 --- a/src/membrain_seg/segmentation/dataloading/memseg_dataset.py +++ b/src/membrain_seg/segmentation/dataloading/memseg_dataset.py @@ -141,7 +141,6 @@ def load_data(self) -> None: self.labels.append(label) self.dataset_labels.append(get_dataset_token(entry[0])) - def initialize_imgs_paths(self) -> None: """ Initializes the list of paths to image-label pairs. @@ -197,6 +196,20 @@ def test(self, test_folder: str, num_files: int = 20) -> None: def get_dataset_token(patch_name): + """ + Get the dataset token from the patch name. + + Parameters + ---------- + patch_name : str + The name of the patch. + + Returns + ------- + str + The dataset token. + + """ basename = os.path.basename(patch_name) dataset_token = basename.split("_")[0] - return dataset_token \ No newline at end of file + return dataset_token diff --git a/src/membrain_seg/segmentation/networks/unet.py b/src/membrain_seg/segmentation/networks/unet.py index 1f9510a..96734f4 100644 --- a/src/membrain_seg/segmentation/networks/unet.py +++ b/src/membrain_seg/segmentation/networks/unet.py @@ -8,14 +8,12 @@ from monai.transforms import AsDiscrete, Compose, EnsureType, Lambda from ..training.metric_utils import masked_accuracy, threshold_function - from ..training.optim_utils import ( + CombinedLoss, DeepSuperVisionLoss, - DynUNetDirectDeepSupervision, + DynUNetDirectDeepSupervision, IgnoreLabelDiceCELoss, - CombinedLoss, ) - from ..training.surface_dice import IgnoreLabelSurfaceDiceLoss, masked_surface_dice @@ -84,9 +82,9 @@ def __init__( roi_size: Tuple[int, ...] = (160, 160, 160), max_epochs: int = 1000, use_deep_supervision: bool = False, - use_surf_dice=False, - surf_dice_weight=1.0, - surf_dice_tokens=[], + use_surf_dice: bool = False, + surf_dice_weight: float = 1.0, + surf_dice_tokens: list = None, ): super().__init__() @@ -113,7 +111,6 @@ def __init__( deep_supr_num=2, ) - ### Build up loss function losses = [] @@ -121,19 +118,26 @@ def __init__( loss_inclusion_tokens = [] ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="mean") losses.append(ignore_dice_loss) - weights.append(1.) 
- loss_inclusion_tokens.append(["all"]) # Apply to every element + weights.append(1.0) + loss_inclusion_tokens.append(["all"]) # Apply to every element if use_surf_dice: - ignore_surf_dice_loss = IgnoreLabelSurfaceDiceLoss(ignore_label=2, soft_skel_iterations=3) + if surf_dice_tokens is None: + surf_dice_tokens = ["all"] + ignore_surf_dice_loss = IgnoreLabelSurfaceDiceLoss( + ignore_label=2, soft_skel_iterations=3 + ) losses.append(ignore_surf_dice_loss) weights.append(surf_dice_weight) loss_inclusion_tokens.append(surf_dice_tokens) scaled_weights = [entry / sum(weights) for entry in weights] - loss_function = CombinedLoss(losses=losses, weights=scaled_weights, - loss_exclusion_tokens=loss_inclusion_tokens) - + loss_function = CombinedLoss( + losses=losses, + weights=scaled_weights, + loss_exclusion_tokens=loss_inclusion_tokens, + ) + self.loss_function = DeepSuperVisionLoss( loss_function, weights=[1.0, 0.5, 0.25, 0.125, 0.0675] @@ -141,7 +145,6 @@ def __init__( else [1.0, 0.0, 0.0, 0.0, 0.0], ) - # validation metric self.dice_metric = DiceMetric( include_background=False, reduction="mean", get_not_nans=False @@ -207,7 +210,7 @@ def training_step( See the pytorch-lightning module documentation for details. """ - images, labels, ds_label = batch["image"], batch["label"], batch['dataset'] + images, labels, ds_label = batch["image"], batch["label"], batch["dataset"] output = self.forward(images) loss = self.loss_function(output, labels, ds_label) @@ -219,11 +222,14 @@ def training_step( ) self.running_train_surf_dice += ( - masked_surface_dice(data=output[0].detach(), - target=labels[0].detach(), - ignore_label=2., - soft_skel_iterations=3, - smooth=1.) * output[0].shape[0] + masked_surface_dice( + data=output[0].detach(), + target=labels[0].detach(), + ignore_label=2.0, + soft_skel_iterations=3, + smooth=1.0, + ) + * output[0].shape[0] ) return {"loss": loss} @@ -245,9 +251,9 @@ def on_train_epoch_end(self): mean_train_surf_dice = self.running_train_surf_dice / num_items self.running_train_acc = 0.0 self.running_train_surf_dice = 0.0 - self.log("train_loss", mean_train_loss) - self.log("train_acc", mean_train_acc) - self.log("train_surf_dice", mean_train_surf_dice) + self.log("train_loss", mean_train_loss) + self.log("train_acc", mean_train_acc) + self.log("train_surf_dice", mean_train_surf_dice) self.training_step_outputs = [] print("EPOCH Training loss", mean_train_loss.item()) @@ -291,11 +297,14 @@ def validation_step(self, batch, batch_idx): ) self.running_val_surf_dice += ( - masked_surface_dice(data=outputs[0].detach(), - target=labels[0].detach(), - ignore_label=2., - soft_skel_iterations=3, - smooth=1.) 
* outputs[0].shape[0] + masked_surface_dice( + data=outputs[0].detach(), + target=labels[0].detach(), + ignore_label=2.0, + soft_skel_iterations=3, + smooth=1.0, + ) + * outputs[0].shape[0] ) return stats_dict @@ -317,8 +326,8 @@ def on_validation_epoch_end(self): mean_val_surf_dice = self.running_val_surf_dice / num_items self.running_val_acc = 0.0 self.running_val_surf_dice = 0.0 - self.log("val_loss", mean_val_loss), - self.log("val_dice", mean_val_dice) + self.log("val_loss", mean_val_loss), + self.log("val_dice", mean_val_dice) self.log("val_surf_dice", mean_val_surf_dice) self.log("val_accuracy", mean_val_acc) diff --git a/src/membrain_seg/segmentation/train.py b/src/membrain_seg/segmentation/train.py index 7c059d7..331ced1 100644 --- a/src/membrain_seg/segmentation/train.py +++ b/src/membrain_seg/segmentation/train.py @@ -26,7 +26,7 @@ def train( sub_name: str = "1", use_surf_dice: bool = False, surf_dice_weight: float = 1.0, - surf_dice_tokens: list = ["all"], + surf_dice_tokens: list = None, ): """ Train the model on the specified data. @@ -55,6 +55,12 @@ def train( Name of the project for logging purposes. sub_name : str, optional Sub-name of the project for logging purposes. + use_surf_dice : bool, optional + If True, enables Surface-Dice loss. + surf_dice_weight : float, optional + Weight for the Surface-Dice loss. + surf_dice_tokens : list, optional + List of tokens to use for the Surface-Dice loss. Returns ------- @@ -70,10 +76,10 @@ def train( # Set up the model model = SemanticSegmentationUnet( - max_epochs=max_epochs, - use_deep_supervision=use_deep_supervision, - use_surf_dice=use_surf_dice, - surf_dice_weight=surf_dice_weight, + max_epochs=max_epochs, + use_deep_supervision=use_deep_supervision, + use_surf_dice=use_surf_dice, + surf_dice_weight=surf_dice_weight, surf_dice_tokens=surf_dice_tokens, ) diff --git a/src/membrain_seg/segmentation/training/optim_utils.py b/src/membrain_seg/segmentation/training/optim_utils.py index c070ceb..ab050e7 100644 --- a/src/membrain_seg/segmentation/training/optim_utils.py +++ b/src/membrain_seg/segmentation/training/optim_utils.py @@ -95,23 +95,49 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor: orig_data, target_tensor, reduction="none" ) bce_loss[~mask] = 0.0 - #TODO: Check if this is correct: I adjusted the loss to be computed per batch element + # TODO: Check if this is correct: I adjusted the loss to be + # computed per batch element bce_loss = torch.sum(bce_loss, dim=0) / torch.sum(mask, dim=0) - print(bce_loss.shape, "<------------- Please check the shape of the BCE loss here!!") - print(bce_loss.shape, "<------------- Please check the shape of the BCE loss here!!") - print(bce_loss.shape, "<------------- Please check the shape of the BCE loss here!!") - print(bce_loss.shape, "<------------- Please check the shape of the BCE loss here!!") + print( + bce_loss.shape, + "<------------- Please check the shape of the BCE loss here!!", + ) + print( + bce_loss.shape, + "<------------- Please check the shape of the BCE loss here!!", + ) + print( + bce_loss.shape, + "<------------- Please check the shape of the BCE loss here!!", + ) + print( + bce_loss.shape, + "<------------- Please check the shape of the BCE loss here!!", + ) dice_loss = self.dice_loss(data, target, mask) dice_loss = torch.sum(dice_loss, dim=0) / torch.sum(mask, dim=0) - - print(bce_loss.shape, "<------------- Please check the shape of the Dice loss here!!") - print(bce_loss.shape, "<------------- Please check the shape of the Dice loss 
here!!") - print(bce_loss.shape, "<------------- Please check the shape of the Dice loss here!!") - print(bce_loss.shape, "<------------- Please check the shape of the Dice loss here!!") + print( + bce_loss.shape, + "<------------- Please check the shape of the Dice loss here!!", + ) + print( + bce_loss.shape, + "<------------- Please check the shape of the Dice loss here!!", + ) + print( + bce_loss.shape, + "<------------- Please check the shape of the Dice loss here!!", + ) + print( + bce_loss.shape, + "<------------- Please check the shape of the Dice loss here!!", + ) print("Also check values!") # Combine the Dice and Cross Entropy losses - combined_loss = self.lambda_dice * dice_loss + self.lambda_ce * bce_loss # Combined loss should be per batch element + combined_loss = ( + self.lambda_dice * dice_loss + self.lambda_ce * bce_loss + ) # Combined loss should be per batch element return combined_loss @@ -177,19 +203,19 @@ class CombinedLoss(_Loss): losses : List[Callable] A list of loss function instances. weights : List[float] - List of weights corresponding to each loss function (must + List of weights corresponding to each loss function (must be of same length as losses). loss_inclusion_tokens : List[List[str]] A list of lists containing tokens for each loss function. - Each sublist corresponds to a loss function and contains + Each sublist corresponds to a loss function and contains tokens for which the loss should be included. If the list contains "all", then the loss will be included for all cases. Notes - ---------- - IMPORTANT: Loss functions need to return a tensors containing the - loss for each batch element. + ----- + IMPORTANT: Loss functions need to return a tensors containing the + loss for each batch element. The loss_exclusion_tokens parameter is used to exclude certain cases from the loss calculation. For example, if the loss_exclusion_tokens @@ -198,7 +224,7 @@ class CombinedLoss(_Loss): and the second loss function will be excluded for cases where the dataset label is "ds1". - + """ def __init__( @@ -213,7 +239,9 @@ def __init__( self.weights = weights self.loss_inclusion_tokens = loss_inclusion_tokens - def forward(self, data: torch.Tensor, target: torch.Tensor, ds_label: list) -> torch.Tensor: + def forward( + self, data: torch.Tensor, target: torch.Tensor, ds_label: list + ) -> torch.Tensor: """ Compute the combined loss. @@ -231,24 +259,27 @@ def forward(self, data: torch.Tensor, target: torch.Tensor, ds_label: list) -> t torch.Tensor The calculated combined loss. """ - - loss = 0. - for loss_idx, (cur_loss, cur_weight, skip_cases) in enumerate(zip(self.losses, self.weights, self.apply_loss_not_for)): + loss = 0.0 + for loss_idx, (cur_loss, cur_weight, _skip_cases) in enumerate( + zip(self.losses, self.weights, self.apply_loss_not_for) + ): cur_loss_val = cur_loss(data, target) print("CHECK WHETHER THE COMBINED LOSS IS CORRECTLY COMPUTED HERE!") print("CHECK WHETHER THE COMBINED LOSS IS CORRECTLY COMPUTED HERE!") # Zero out losses for excluded cases for batch_idx, ds_lab in ds_label: - if "all" in self.loss_inclusion_tokens[loss_idx] or \ - ds_lab in self.loss_inclusion_tokens[loss_idx]: + if ( + "all" in self.loss_inclusion_tokens[loss_idx] + or ds_lab in self.loss_inclusion_tokens[loss_idx] + ): continue - cur_loss_val[batch_idx] = 0. 
+ cur_loss_val[batch_idx] = 0.0 # Aggregate loss - cur_loss_val = cur_loss_val.sum() / ((cur_loss_val != 0.).sum() + 1e-3) + cur_loss_val = cur_loss_val.sum() / ((cur_loss_val != 0.0).sum() + 1e-3) loss += cur_weight * cur_loss_val - + # Normalize loss loss = loss / self.weights.sum() return loss diff --git a/src/membrain_seg/segmentation/training/surface_dice.py b/src/membrain_seg/segmentation/training/surface_dice.py index f0c6868..deb3857 100644 --- a/src/membrain_seg/segmentation/training/surface_dice.py +++ b/src/membrain_seg/segmentation/training/surface_dice.py @@ -1,13 +1,18 @@ """ -Adapted from: clDice - A Novel Topology-Preserving Loss Function for Tubular Structure Segmentation +Surface Dice implementation. + +Adapted from: clDice - A Novel Topology-Preserving Loss Function for Tubular +Structure Segmentation Original Authors: Johannes C. Paetzold and Suprosanna Shit -Sources: https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py +Sources: https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/ + soft_skeleton.py https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py -License: MIT License +License: MIT License. -The following code is a modification of the original clDice implementation. Modifications were made to -include additional functionality and integrate with new project requirements. The original license and -copyright notice are provided below. +The following code is a modification of the original clDice implementation. +Modifications were made to include additional functionality and integrate +with new project requirements. The original license and copyright notice are +provided below. MIT License @@ -34,11 +39,9 @@ import torch import torch.nn.functional as F -from .soft_skeleton import soft_skel +from scipy.ndimage import gaussian_filter from torch.nn.functional import sigmoid from torch.nn.modules.loss import _Loss -from scipy.ndimage import gaussian_filter - def soft_erode(img: torch.Tensor, separate_pool: bool = False) -> torch.Tensor: @@ -67,18 +70,18 @@ def soft_erode(img: torch.Tensor, separate_pool: bool = False) -> torch.Tensor: Notes ----- - - The soft erosion can be performed with separate 3D min-pooling operations - along different axes if separate_pool is True, or with a single - 3D min-pooling operation with a kernel of size (3, 3, 3) if + - The soft erosion can be performed with separate 3D min-pooling operations + along different axes if separate_pool is True, or with a single + 3D min-pooling operation with a kernel of size (3, 3, 3) if separate_pool is False. """ - assert len(img.shape)==5 + assert len(img.shape) == 5 if separate_pool: - p1 = -F.max_pool3d(-img,(3,1,1),(1,1,1),(1,0,0)) - p2 = -F.max_pool3d(-img,(1,3,1),(1,1,1),(0,1,0)) - p3 = -F.max_pool3d(-img,(1,1,3),(1,1,1),(0,0,1)) + p1 = -F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0)) + p2 = -F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0)) + p3 = -F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1)) return torch.min(torch.min(p1, p2), p3) - p4 = -F.max_pool3d(-img,(3,3,3),(1,1,1),(1,1,1)) + p4 = -F.max_pool3d(-img, (3, 3, 3), (1, 1, 1), (1, 1, 1)) return p4 @@ -105,11 +108,11 @@ def soft_dilate(img: torch.Tensor) -> torch.Tensor: Notes ----- - - For 5D input, the soft dilation is performed using a 3D max-pooling operation + - For 5D input, the soft dilation is performed using a 3D max-pooling operation with a kernel of size (3, 3, 3). 
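+    - Stride 1 and padding 1 keep the output shape identical to the input
+      shape.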
""" - assert len(img.shape)==5 - return F.max_pool3d(img,(3,3,3),(1,1,1),(1,1,1)) + assert len(img.shape) == 5 + return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), (1, 1, 1)) def soft_open(img: torch.Tensor, separate_pool: bool = False) -> torch.Tensor: @@ -133,15 +136,17 @@ def soft_open(img: torch.Tensor, separate_pool: bool = False) -> torch.Tensor: Notes ----- - - Soft opening is performed by applying soft erosion followed by soft dilation + - Soft opening is performed by applying soft erosion followed by soft dilation to the input image. - - For 5D input, separate erosion and dilation can be performed if separate_pool + - For 5D input, separate erosion and dilation can be performed if separate_pool is True. """ return soft_dilate(soft_erode(img, separate_pool=separate_pool)) -def soft_skel(img: torch.Tensor, iter_: int, separate_pool: bool = False) -> torch.Tensor: +def soft_skel( + img: torch.Tensor, iter_: int, separate_pool: bool = False +) -> torch.Tensor: """ Compute the soft skeleton of the input image. @@ -150,7 +155,7 @@ def soft_skel(img: torch.Tensor, iter_: int, separate_pool: bool = False) -> tor computed and added to the skeleton. Reasoning: if there is a difference between the input image and the "opened" image, - there must be a thin membrane skeleton in the input image that was removed by the + there must be a thin membrane skeleton in the input image that was removed by the opening operation. Parameters @@ -160,7 +165,8 @@ def soft_skel(img: torch.Tensor, iter_: int, separate_pool: bool = False) -> tor iter_ : int Number of iterations for skeletonization. separate_pool : bool, optional - If True, perform separate erosion and dilation operations. Default is False. + If True, perform separate erosion and dilation operations. + Default is False. Returns ------- @@ -171,13 +177,13 @@ def soft_skel(img: torch.Tensor, iter_: int, separate_pool: bool = False) -> tor ----- - Separate erosion can be performed if separate_pool is True. """ - img1 = soft_open(img, separate_pool=separate_pool) - skel = F.relu(img-img1) - for j in range(iter_): - img = soft_erode(img) - img1 = soft_open(img, separate_pool=separate_pool) - delta = F.relu(img-img1) - skel = skel + F.relu(delta-skel*delta) + img1 = soft_open(img, separate_pool=separate_pool) + skel = F.relu(img - img1) + for _j in range(iter_): + img = soft_erode(img) + img1 = soft_open(img, separate_pool=separate_pool) + delta = F.relu(img - img1) + skel = skel + F.relu(delta - skel * delta) return skel @@ -185,17 +191,19 @@ def get_GT_skeleton(gt_seg: torch.Tensor, iterations: int = 5) -> torch.Tensor: """ Generate the skeleton of a ground truth segmentation. - This function takes a ground truth segmentation `gt_seg`, smooths it using a Gaussian filter, - and then computes its soft skeleton using the `soft_skel` function. + This function takes a ground truth segmentation `gt_seg`, smooths it using a + Gaussian filter, and then computes its soft skeleton using the `soft_skel` function. - Intention: When using the binary ground truth segmentation for skeletonization, the resulting - skeleton is very patchy and not smooth. When using the smoothed ground truth segmentation, the - resulting skeleton is much smoother and more accurate. + Intention: When using the binary ground truth segmentation for skeletonization, + the resulting skeleton is very patchy and not smooth. When using the smoothed + ground truth segmentation, the resulting skeleton is much smoother and more + accurate. 
Parameters ---------- gt_seg : torch.Tensor - A torch.Tensor representing the ground truth segmentation. Shape: (B, C, D, H, W) + A torch.Tensor representing the ground truth segmentation. + Shape: (B, C, D, H, W) iterations : int, optional The number of iterations for skeletonization. Default is 5. @@ -206,66 +214,88 @@ def get_GT_skeleton(gt_seg: torch.Tensor, iterations: int = 5) -> torch.Tensor: Notes ----- - - The input `gt_seg` should be a binary segmentation tensor where 1 represents the object of interest. - - The function first smooths the `gt_seg` using a Gaussian filter to enhance the object's structure. - - The skeletonization process is performed using the `soft_skel` function with the specified number of - iterations. - - The resulting skeleton is returned as a binary torch.Tensor where 1 indicates the skeleton points. + - The input `gt_seg` should be a binary segmentation tensor where 1 represents the + object of interest. + - The function first smooths the `gt_seg` using a Gaussian filter to enhance the + object's structure. + - The skeletonization process is performed using the `soft_skel` function with the + specified number of iterations. + - The resulting skeleton is returned as a binary torch.Tensor where 1 indicates the + skeleton points. """ - gt_smooth = gaussian_filter((gt_seg == 1) * 1., 2) * 1.5 # The Gaussian smoothing parameters are work in progress - skel_gt = soft_skel(torch.from_numpy(gt_smooth)*1.0, iter_=iterations) + gt_smooth = ( + gaussian_filter((gt_seg == 1) * 1.0, 2) * 1.5 + ) # The Gaussian smoothing parameters are work in progress + skel_gt = soft_skel(torch.from_numpy(gt_smooth) * 1.0, iter_=iterations) return skel_gt -def masked_surface_dice(data: torch.Tensor, target: torch.Tensor, ignore_label: int=2, soft_skel_iterations: int=3, - smooth: float=3., binary_prediction: bool=False) -> torch.Tensor: - """ - Compute the surface Dice loss with masking for ignore labels. +def masked_surface_dice( + data: torch.Tensor, + target: torch.Tensor, + ignore_label: int = 2, + soft_skel_iterations: int = 3, + smooth: float = 3.0, + binary_prediction: bool = False, +) -> torch.Tensor: + """ + Compute the surface Dice loss with masking for ignore labels. - The surface Dice loss measures the similarity between the predicted segmentation's skeleton and - the ground truth segmentation (and vice versa). Labels annotated with "ignore_label" are ignored. + The surface Dice loss measures the similarity between the predicted segmentation's + skeleton and the ground truth segmentation (and vice versa). Labels annotated with + "ignore_label" are ignored. - Parameters - ---------- - data : torch.Tensor - Tensor of model outputs representing the predicted segmentation. - target : torch.Tensor - Tensor of target labels representing the ground truth segmentation. - ignore_label : int - The label value to be ignored when computing the loss. - soft_skel_iterations : int - Number of iterations for skeletonization in the underlying operations. - smooth : float - Smoothing factor to avoid division by zero. + Parameters + ---------- + data : torch.Tensor + Tensor of model outputs representing the predicted segmentation. + target : torch.Tensor + Tensor of target labels representing the ground truth segmentation. + ignore_label : int + The label value to be ignored when computing the loss. + soft_skel_iterations : int + Number of iterations for skeletonization in the underlying operations. + smooth : float + Smoothing factor to avoid division by zero. 
+ binary_prediction : bool + If True, the predicted segmentation is assumed to be binary. Default is False. - Returns - ------- - torch.Tensor - The calculated surface Dice loss. - """ - # Create a mask to ignore the specified label in the target - data = sigmoid(data) - mask = target != ignore_label - - # Compute soft skeletonization - if binary_prediction: - skel_pred = get_GT_skeleton(data.clone(), soft_skel_iterations, separate_pool=False) - else: - skel_pred = soft_skel(data.clone(), soft_skel_iterations, separate_pool=False) - skel_true = get_GT_skeleton(target.clone(), soft_skel_iterations, separate_pool=False) - - # Mask out ignore labels - skel_pred[~mask] = 0 - skel_true[~mask] = 0 - - # compute surface dice loss - print("ALSO CHECK DIMENSIONS AND VALUES HERE!!! (Surface dice loss)") - print("ALSO CHECK DIMENSIONS AND VALUES HERE!!! (Surface dice loss)") - print("ALSO CHECK DIMENSIONS AND VALUES HERE!!! (Surface dice loss)") - tprec = (torch.sum(torch.multiply(skel_pred, target), dim=0)+smooth)/(torch.sum(skel_pred, dim=0)+smooth) - tsens = (torch.sum(torch.multiply(skel_true, data), dim=0)+smooth)/(torch.sum(skel_true, dim=0)+smooth) - surf_dice_loss = 2.0*(tprec*tsens)/(tprec+tsens) - return surf_dice_loss + Returns + ------- + torch.Tensor + The calculated surface Dice loss. + """ + # Create a mask to ignore the specified label in the target + data = sigmoid(data) + mask = target != ignore_label + + # Compute soft skeletonization + if binary_prediction: + skel_pred = get_GT_skeleton( + data.clone(), soft_skel_iterations, separate_pool=False + ) + else: + skel_pred = soft_skel(data.clone(), soft_skel_iterations, separate_pool=False) + skel_true = get_GT_skeleton( + target.clone(), soft_skel_iterations, separate_pool=False + ) + + # Mask out ignore labels + skel_pred[~mask] = 0 + skel_true[~mask] = 0 + + # compute surface dice loss + print("ALSO CHECK DIMENSIONS AND VALUES HERE!!! (Surface dice loss)") + print("ALSO CHECK DIMENSIONS AND VALUES HERE!!! (Surface dice loss)") + print("ALSO CHECK DIMENSIONS AND VALUES HERE!!! (Surface dice loss)") + tprec = (torch.sum(torch.multiply(skel_pred, target), dim=0) + smooth) / ( + torch.sum(skel_pred, dim=0) + smooth + ) + tsens = (torch.sum(torch.multiply(skel_true, data), dim=0) + smooth) / ( + torch.sum(skel_true, dim=0) + smooth + ) + surf_dice_loss = 2.0 * (tprec * tsens) / (tprec + tsens) + return surf_dice_loss class IgnoreLabelSurfaceDiceLoss(_Loss): @@ -286,7 +316,7 @@ def __init__( self, ignore_label: int, soft_skel_iterations: int = 3, - smooth: float = 3., + smooth: float = 3.0, **kwargs, ) -> None: super().__init__() @@ -311,8 +341,13 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor: The calculated loss. """ # Create a mask to ignore the specified label in the target - surf_dice_score = masked_surface_dice(data=data, target=target, ignore_label=self.ignore_label, soft_skel_iterations=self.soft_skel_iterations, smooth=self.smooth) - surf_dice_loss = 1. 
- surf_dice_score - + surf_dice_score = masked_surface_dice( + data=data, + target=target, + ignore_label=self.ignore_label, + soft_skel_iterations=self.soft_skel_iterations, + smooth=self.smooth, + ) + surf_dice_loss = 1.0 - surf_dice_score - return surf_dice_loss \ No newline at end of file + return surf_dice_loss From a3c0ce2c8007a68e283dd2744835f68a93e79c35 Mon Sep 17 00:00:00 2001 From: LorenzLamm Date: Wed, 3 Jan 2024 14:18:34 +0100 Subject: [PATCH 08/23] make list readable by passing argument multiple times --- .../segmentation/cli/train_cli.py | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/membrain_seg/segmentation/cli/train_cli.py b/src/membrain_seg/segmentation/cli/train_cli.py index 47b3770..8cc6132 100644 --- a/src/membrain_seg/segmentation/cli/train_cli.py +++ b/src/membrain_seg/segmentation/cli/train_cli.py @@ -1,4 +1,7 @@ +from typing import List, Optional + from typer import Option +from typing_extensions import Annotated from ..train import train as _train from .cli import OPTION_PROMPT_KWARGS as PKWARGS @@ -70,7 +73,7 @@ def train_advanced( help="Batch size for training.", ), num_workers: int = Option( # noqa: B008 - 1, + 8, help="Number of worker threads for loading data", ), max_epochs: int = Option( # noqa: B008 @@ -85,17 +88,21 @@ def train_advanced( Pass "True" or "False".', ), use_surface_dice: bool = Option( # noqa: B008 - True, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".' + False, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".' ), surface_dice_weight: float = Option( # noqa: B008 1.0, help="Scaling factor for the Surface-Dice loss. " ), - surface_dice_tokens: list = Option( # noqa: B008 - ["all"], - help='List of tokens to use for the Surface-Dice loss. \ - Pass a list of strings.\ - For example, ["all", "membrane"]', - ), + surface_dice_tokens: Annotated[ + Optional[List[str]], + Option( + help='List of tokens to \ + use for the Surface-Dice loss. \ + Pass tokens separately:\ + For example, train_advanced --surface_dice_tokens "ds1" \ + --surface_dice_tokens "ds2"' + ), + ] = None, use_deep_supervision: bool = Option( # noqa: B008 True, help='Whether to use deep supervision. Pass "True" or "False".' 
From 96607908092fbe39046437321acb6f3c3b8980e5 Mon Sep 17 00:00:00 2001
From: LorenzLamm
Date: Wed, 3 Jan 2024 14:21:33 +0100
Subject: [PATCH 09/23] remove redundant import

---
 src/membrain_seg/segmentation/dataloading/memseg_dataset.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/membrain_seg/segmentation/dataloading/memseg_dataset.py b/src/membrain_seg/segmentation/dataloading/memseg_dataset.py
index 9af7c0f..9b66dda 100644
--- a/src/membrain_seg/segmentation/dataloading/memseg_dataset.py
+++ b/src/membrain_seg/segmentation/dataloading/memseg_dataset.py
@@ -1,7 +1,6 @@
 import os
 from typing import Dict
 
-# from skimage import io
 import imageio as io
 import numpy as np
 from torch.utils.data import Dataset

From 94a1c81c09924867e5fc4fb75e49dd72299bf153 Mon Sep 17 00:00:00 2001
From: LorenzLamm
Date: Wed, 3 Jan 2024 14:26:55 +0100
Subject: [PATCH 10/23] Compatibility with updated masked_surface_dice function

---
 src/membrain_seg/segmentation/networks/unet.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/membrain_seg/segmentation/networks/unet.py b/src/membrain_seg/segmentation/networks/unet.py
index 96734f4..1daac61 100644
--- a/src/membrain_seg/segmentation/networks/unet.py
+++ b/src/membrain_seg/segmentation/networks/unet.py
@@ -112,7 +112,6 @@ def __init__(
         )
 
         ### Build up loss function
-
         losses = []
         weights = []
         loss_inclusion_tokens = []
@@ -125,17 +124,18 @@ def __init__(
         if surf_dice_tokens is None:
             surf_dice_tokens = ["all"]
         ignore_surf_dice_loss = IgnoreLabelSurfaceDiceLoss(
-            ignore_label=2, soft_skel_iterations=3
+            ignore_label=2, soft_skel_iterations=5
         )
         losses.append(ignore_surf_dice_loss)
         weights.append(surf_dice_weight)
         loss_inclusion_tokens.append(surf_dice_tokens)
 
         scaled_weights = [entry / sum(weights) for entry in weights]
+
         loss_function = CombinedLoss(
             losses=losses,
             weights=scaled_weights,
-            loss_exclusion_tokens=loss_inclusion_tokens,
+            loss_inclusion_tokens=loss_inclusion_tokens,
         )
 
         self.loss_function = DeepSuperVisionLoss(
@@ -220,14 +220,14 @@ def training_step(
             masked_accuracy(output[0], labels[0], ignore_label=2.0, threshold_value=0.0)
             * output[0].shape[0]
         )
-
         self.running_train_surf_dice += (
             masked_surface_dice(
                 data=output[0].detach(),
                 target=labels[0].detach(),
                 ignore_label=2.0,
-                soft_skel_iterations=3,
+                soft_skel_iterations=5,
                 smooth=1.0,
+                reduction="mean",
             )
             * output[0].shape[0]
         )
@@ -301,12 +301,12 @@ def validation_step(self, batch, batch_idx):
                 data=outputs[0].detach(),
                 target=labels[0].detach(),
                 ignore_label=2.0,
-                soft_skel_iterations=3,
+                soft_skel_iterations=5,
                 smooth=1.0,
+                reduction="mean",
             )
             * outputs[0].shape[0]
         )
-
         return stats_dict
 
     def on_validation_epoch_end(self):
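[Aside on PATCH 10] The CombinedLoss wiring introduced above keys each loss term to dataset tokens, so a term contributes only for batch elements whose token is listed (or when "all" is given). A rough sketch of that gating idea, assuming the loss has already been reduced to one value per batch element; the function and variable names are illustrative only:

    import torch


    def gate_by_token(per_element_loss, ds_labels, inclusion_tokens):
        # Keep an element's loss only if its dataset token is included,
        # or if "all" is allowed; then average over the kept elements.
        keep = torch.tensor(
            [("all" in inclusion_tokens) or (lab in inclusion_tokens) for lab in ds_labels],
            dtype=per_element_loss.dtype,
        )
        return (per_element_loss * keep).sum() / keep.sum().clamp(min=1.0)


    per_element = torch.tensor([0.40, 0.70, 0.20])  # one loss value per batch element
    print(gate_by_token(per_element, ["ds1", "ds2", "ds1"], ["ds1"]))  # tensor(0.3000)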
From 722a5eded16c55aac672c38675bd76fbc4f53c86 Mon Sep 17 00:00:00 2001
From: LorenzLamm
Date: Wed, 3 Jan 2024 14:29:42 +0100
Subject: [PATCH 11/23] Add training summary and remove wandb logging

---
 src/membrain_seg/segmentation/train.py | 22 ++++++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/src/membrain_seg/segmentation/train.py b/src/membrain_seg/segmentation/train.py
index 331ced1..9c576f5 100644
--- a/src/membrain_seg/segmentation/train.py
+++ b/src/membrain_seg/segmentation/train.py
@@ -9,6 +9,9 @@
     MemBrainSegDataModule,
 )
 from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet
+from membrain_seg.segmentation.training.training_param_summary import (
+    print_training_parameters,
+)
 
 warnings.filterwarnings("ignore", category=UserWarning, module="torch._tensor")
 warnings.filterwarnings("ignore", category=UserWarning, module="monai.data")
@@ -66,6 +69,20 @@ def train(
     -------
     None
     """
+    print_training_parameters(
+        data_dir=data_dir,
+        log_dir=log_dir,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        max_epochs=max_epochs,
+        aug_prob_to_one=aug_prob_to_one,
+        use_deep_supervision=use_deep_supervision,
+        project_name=project_name,
+        sub_name=sub_name,
+        use_surf_dice=use_surf_dice,
+        surf_dice_weight=surf_dice_weight,
+        surf_dice_tokens=surf_dice_tokens,
+    )
     # Set up the data module
     data_module = MemBrainSegDataModule(
         data_dir=data_dir,
@@ -86,9 +103,6 @@ def train(
     project_name = project_name
     checkpointing_name = project_name + "_" + sub_name
     # Set up logging
-    wandb_logger = pl_loggers.WandbLogger(
-        project=project_name, log_model=False, save_code=True
-    )
     csv_logger = pl_loggers.CSVLogger(log_dir)
 
     # Set up model checkpointing
@@ -119,7 +133,7 @@ def on_epoch_start(self, trainer, pl_module):
     # Set up the trainer
     trainer = pl.Trainer(
         precision="16-mixed",
-        logger=[csv_logger, wandb_logger],
+        logger=[csv_logger],
         callbacks=[
             checkpoint_callback_val_loss,
             checkpoint_callback_regular,

From 044a7d698d296c135dd5c54f8096adcef0d5d58b Mon Sep 17 00:00:00 2001
From: LorenzLamm
Date: Wed, 3 Jan 2024 14:31:35 +0100
Subject: [PATCH 12/23] remove redundant print statements and include
 ds_labels into for-loop

---
 .../segmentation/training/optim_utils.py     | 72 +++++++--------------
 1 file changed, 23 insertions(+), 49 deletions(-)

diff --git a/src/membrain_seg/segmentation/training/optim_utils.py b/src/membrain_seg/segmentation/training/optim_utils.py
index ab050e7..bd807ab 100644
--- a/src/membrain_seg/segmentation/training/optim_utils.py
+++ b/src/membrain_seg/segmentation/training/optim_utils.py
@@ -97,47 +97,20 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         bce_loss[~mask] = 0.0
         # TODO: Check if this is correct: I adjusted the loss to be
         # computed per batch element
-        bce_loss = torch.sum(bce_loss, dim=0) / torch.sum(mask, dim=0)
-        print(
-            bce_loss.shape,
-            "<------------- Please check the shape of the BCE loss here!!",
-        )
-        print(
-            bce_loss.shape,
-            "<------------- Please check the shape of the BCE loss here!!",
-        )
-        print(
-            bce_loss.shape,
-            "<------------- Please check the shape of the BCE loss here!!",
-        )
-        print(
-            bce_loss.shape,
-            "<------------- Please check the shape of the BCE loss here!!",
-        )
+        bce_loss = torch.sum(bce_loss, dim=(1, 2, 3, 4)) / torch.sum(
+            mask, dim=(1, 2, 3, 4)
+        )
 
-        dice_loss = self.dice_loss(data, target, mask)
-        dice_loss = torch.sum(dice_loss, dim=0) / torch.sum(mask, dim=0)
+        # Compute Dice loss separately for each batch element
+        dice_loss = torch.zeros_like(bce_loss)
+        for batch_idx in range(data.shape[0]):
+            dice_loss[batch_idx] = self.dice_loss(
+                data[batch_idx].unsqueeze(0),
+                target[batch_idx].unsqueeze(0),
+                mask[batch_idx].unsqueeze(0),
+            )
 
-        print(
-            bce_loss.shape,
-            "<------------- Please check the shape of the Dice loss here!!",
-        )
-        print(
-            bce_loss.shape,
-            "<------------- Please check the shape of the Dice loss here!!",
-        )
-        print(
-            bce_loss.shape,
-            "<------------- Please check the shape of the Dice loss here!!",
-        )
-        print(
-            bce_loss.shape,
-            "<------------- Please check the shape of the Dice loss here!!",
-        )
-        print("Also check values!")
         # Combine the Dice and Cross Entropy losses
-        combined_loss = (
-            self.lambda_dice * dice_loss + self.lambda_ce * bce_loss
-        )  # Combined loss should be per batch element
+        combined_loss = self.lambda_dice * dice_loss + self.lambda_ce * bce_loss
 
         return combined_loss
 
@@ -172,7 +145,7 @@ def __init__(
         self.loss_fn = loss_fn
         self.weights = weights
 
-    def forward(self, inputs: list, targets: list) -> torch.Tensor:
+    def forward(self, inputs: list, targets: list, ds_labels: list) -> torch.Tensor:
         """
         Compute the loss.
 
@@ -182,6 +155,8 @@ def forward(self, inputs: list, targets: list) -> torch.Tensor:
             List of tensors of model outputs.
         targets : list
             List of tensors of target labels.
+        ds_labels : list
+            List of dataset labels for each batch element.
 
         Returns
         -------
@@ -189,8 +164,11 @@ def forward(self, inputs: list, targets: list) -> torch.Tensor:
             The calculated loss.
         """
         loss = 0.0
-        for weight, data, target in zip(self.weights, inputs, targets):
-            loss += weight * self.loss_fn(data, target)
+        ds_labels_loop = [ds_labels] * 5
+        for weight, data, target, ds_label in zip(
+            self.weights, inputs, targets, ds_labels_loop
+        ):
+            loss += weight * self.loss_fn(data, target, ds_label)
         return loss
 
 
@@ -223,8 +201,6 @@ class CombinedLoss(_Loss):
     will be excluded for cases where the dataset label is "ds1" or "ds2",
     and the second loss function will be excluded for cases where the
     dataset label is "ds1".
-
-
     """
 
     def __init__(
@@ -260,15 +236,13 @@ def forward(
             The calculated combined loss.
         """
         loss = 0.0
-        for loss_idx, (cur_loss, cur_weight, _skip_cases) in enumerate(
-            zip(self.losses, self.weights, self.apply_loss_not_for)
+        for loss_idx, (cur_loss, cur_weight) in enumerate(
+            zip(self.losses, self.weights)
         ):
             cur_loss_val = cur_loss(data, target)
-            print("CHECK WHETHER THE COMBINED LOSS IS CORRECTLY COMPUTED HERE!")
-            print("CHECK WHETHER THE COMBINED LOSS IS CORRECTLY COMPUTED HERE!")
 
             # Zero out losses for excluded cases
-            for batch_idx, ds_lab in ds_label:
+            for batch_idx, ds_lab in enumerate(ds_label):
                 if (
                     "all" in self.loss_inclusion_tokens[loss_idx]
                     or ds_lab in self.loss_inclusion_tokens[loss_idx]
@@ -281,5 +255,5 @@ def forward(
             loss += cur_weight * cur_loss_val
 
         # Normalize loss
-        loss = loss / self.weights.sum()
+        loss = loss / sum(self.weights)
         return loss
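[Aside on PATCH 12] The core change is the switch from dim=0 to dim=(1, 2, 3, 4): summing over the channel and spatial axes keeps the batch axis, yielding one loss value per batch element, which the token gating above can then zero out individually. A small shape sketch (sizes are arbitrary):

    import torch

    B, C, D, H, W = 2, 1, 4, 4, 4
    voxel_loss = torch.rand(B, C, D, H, W)   # e.g. unreduced BCE per voxel
    mask = torch.rand(B, C, D, H, W) > 0.1   # True where target != ignore_label
    voxel_loss = voxel_loss * mask           # ignored voxels contribute nothing
    per_element = voxel_loss.sum(dim=(1, 2, 3, 4)) / mask.sum(dim=(1, 2, 3, 4)).clamp(min=1)
    print(per_element.shape)  # torch.Size([2]): one value per batch element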
""" gt_smooth = ( - gaussian_filter((gt_seg == 1) * 1.0, 2) * 1.5 - ) # The Gaussian smoothing parameters are work in progress - skel_gt = soft_skel(torch.from_numpy(gt_smooth) * 1.0, iter_=iterations) + apply_gaussian_filter((gt_seg == 1) * 1.0, kernel_size=15, sigma=2.0) * 1.5 + ) + skel_gt = soft_skel(gt_smooth, iter_=iterations) return skel_gt @@ -237,6 +341,7 @@ def masked_surface_dice( soft_skel_iterations: int = 3, smooth: float = 3.0, binary_prediction: bool = False, + reduction: str = "none", ) -> torch.Tensor: """ Compute the surface Dice loss with masking for ignore labels. @@ -259,6 +364,8 @@ def masked_surface_dice( Smoothing factor to avoid division by zero. binary_prediction : bool If True, the predicted segmentation is assumed to be binary. Default is False. + reduction : str + Specifies the reduction to apply to the output. Default is "none". Returns ------- @@ -271,31 +378,27 @@ def masked_surface_dice( # Compute soft skeletonization if binary_prediction: - skel_pred = get_GT_skeleton( - data.clone(), soft_skel_iterations, separate_pool=False - ) + skel_pred = get_GT_skeleton(data.clone(), soft_skel_iterations) else: skel_pred = soft_skel(data.clone(), soft_skel_iterations, separate_pool=False) - skel_true = get_GT_skeleton( - target.clone(), soft_skel_iterations, separate_pool=False - ) + skel_true = get_GT_skeleton(target.clone(), soft_skel_iterations) # Mask out ignore labels skel_pred[~mask] = 0 skel_true[~mask] = 0 # compute surface dice loss - print("ALSO CHECK DIMENSIONS AND VALUES HERE!!! (Surface dice loss)") - print("ALSO CHECK DIMENSIONS AND VALUES HERE!!! (Surface dice loss)") - print("ALSO CHECK DIMENSIONS AND VALUES HERE!!! (Surface dice loss)") - tprec = (torch.sum(torch.multiply(skel_pred, target), dim=0) + smooth) / ( - torch.sum(skel_pred, dim=0) + smooth - ) - tsens = (torch.sum(torch.multiply(skel_true, data), dim=0) + smooth) / ( - torch.sum(skel_true, dim=0) + smooth + tprec = ( + torch.sum(torch.multiply(skel_pred, target), dim=(1, 2, 3, 4)) + smooth + ) / (torch.sum(skel_pred, dim=(1, 2, 3, 4)) + smooth) + tsens = (torch.sum(torch.multiply(skel_true, data), dim=(1, 2, 3, 4)) + smooth) / ( + torch.sum(skel_true, dim=(1, 2, 3, 4)) + smooth ) surf_dice_loss = 2.0 * (tprec * tsens) / (tprec + tsens) - return surf_dice_loss + if reduction == "none": + return surf_dice_loss + elif reduction == "mean": + return torch.mean(surf_dice_loss) class IgnoreLabelSurfaceDiceLoss(_Loss): @@ -349,5 +452,4 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor: smooth=self.smooth, ) surf_dice_loss = 1.0 - surf_dice_score - return surf_dice_loss From 9b3ae9b1190898faacea9156f0dd709f5561f3b0 Mon Sep 17 00:00:00 2001 From: LorenzLamm Date: Wed, 3 Jan 2024 14:52:30 +0100 Subject: [PATCH 14/23] Training summary printing --- .../training/training_param_summary.py | 117 ++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 src/membrain_seg/segmentation/training/training_param_summary.py diff --git a/src/membrain_seg/segmentation/training/training_param_summary.py b/src/membrain_seg/segmentation/training/training_param_summary.py new file mode 100644 index 0000000..d964c90 --- /dev/null +++ b/src/membrain_seg/segmentation/training/training_param_summary.py @@ -0,0 +1,117 @@ +def print_training_parameters( + data_dir: str = "/scicore/home/engel0006/GROUP/pool-engel/Lorenz/MemBrain-seg/data", + log_dir: str = "logs/", + batch_size: int = 2, + num_workers: int = 8, + max_epochs: int = 1000, + aug_prob_to_one: bool = False, + 
use_deep_supervision: bool = False, + project_name: str = "membrain-seg_v0", + sub_name: str = "1", + use_surf_dice: bool = False, + surf_dice_weight: float = 1.0, + surf_dice_tokens: list = None, +): + """ + Print a formatted overview of the training parameters with explanations. + + Parameters + ---------- + data_dir : str, optional + Path to the directory containing training data. + log_dir : str, optional + Path to the directory where logs should be stored. + batch_size : int, optional + Number of samples per batch of input data. + num_workers : int, optional + Number of subprocesses to use for data loading. + max_epochs : int, optional + Maximum number of epochs to train for. + aug_prob_to_one : bool, optional + If True, all augmentation probabilities are set to 1. + use_deep_supervision : bool, optional + If True, enables deep supervision in the U-Net model. + project_name : str, optional + Name of the project for logging purposes. + sub_name : str, optional + Sub-name of the project for logging purposes. + use_surf_dice : bool, optional + If True, enables Surface-Dice loss. + surf_dice_weight : float, optional + Weight for the Surface-Dice loss. + surf_dice_tokens : list, optional + List of tokens to use for the Surface-Dice loss. + + Returns + ------- + None + """ + print("\033[1mTraining Parameters Overview:\033[0m\n") + print( + "Data Directory:\n '{}' \n Path to the directory containing " + "training data.".format(data_dir) + ) + print("————————————————————————————————————————————————————————") + print( + "Log Directory:\n '{}' \n Directory where logs and outputs will " + "be stored.".format(log_dir) + ) + print("————————————————————————————————————————————————————————") + print( + "Batch Size:\n {} \n Number of samples processed in a single batch.".format( + batch_size + ) + ) + print("————————————————————————————————————————————————————————") + print( + "Number of Workers:\n {} \n Subprocesses to use for data " + "loading.".format(num_workers) + ) + print("————————————————————————————————————————————————————————") + print(f"Max Epochs:\n {max_epochs} \n Maximum number of training epochs.") + print("————————————————————————————————————————————————————————") + aug_status = "Enabled" if aug_prob_to_one else "Disabled" + print( + "Augmentation Probability to One:\n {} \n If enabled, sets all " + "augmentation probabilities to 1. 
(strong augmentation)".format(aug_status) + ) + print("————————————————————————————————————————————————————————") + deep_sup_status = "Enabled" if use_deep_supervision else "Disabled" + print( + "Use Deep Supervision:\n {} \n If enabled, activates deep " + "supervision in model.".format(deep_sup_status) + ) + print("————————————————————————————————————————————————————————") + print( + "Project Name:\n '{}' \n Name identifier for the current" + " training session.".format(project_name) + ) + print("————————————————————————————————————————————————————————") + print( + "Sub Name:\n '{}' \n Additional sub-identifier for organizing" + " outputs.".format(sub_name) + ) + print("————————————————————————————————————————————————————————") + surf_dice_status = "Enabled" if use_surf_dice else "Disabled" + print( + "Use Surface Dice:\n {} \n If enabled, includes Surface-Dice in the loss " + "calculation.".format(surf_dice_status) + ) + print("————————————————————————————————————————————————————————") + print( + "Surface Dice Weight:\n {} \n Weighting of the Surface-Dice" + " loss, if enabled.".format(surf_dice_weight) + ) + print("————————————————————————————————————————————————————————") + if surf_dice_tokens: + tokens = ", ".join(surf_dice_tokens) + print( + "Surface Dice Tokens:\n [{}] \n Specific tokens used for " + "Surface-Dice loss. Other tokens will be neglected.".format(tokens) + ) + else: + print( + "Surface Dice Tokens:\n None \n No specific tokens are used for " + "Surface-Dice loss." + ) + print("\n") From 44bd6a145dfe2f46c2ed558b4ec60f63908da2d1 Mon Sep 17 00:00:00 2001 From: LorenzLamm Date: Wed, 3 Jan 2024 15:42:56 +0100 Subject: [PATCH 15/23] add dataset token to CLI --- src/membrain_seg/annotations/extract_patch_cli.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/membrain_seg/annotations/extract_patch_cli.py b/src/membrain_seg/annotations/extract_patch_cli.py index 339e476..a5074f2 100644 --- a/src/membrain_seg/annotations/extract_patch_cli.py +++ b/src/membrain_seg/annotations/extract_patch_cli.py @@ -21,6 +21,11 @@ def extract_patches( help="Path to the folder where extracted patches should be stored. \ (subdirectories will be created)", ), + ds_token: str = Option( # noqa: B008 + "other", + help="Dataset token. Important for distinguishing between different \ + datasets. Should NOT contain underscores!", + ), coords_file: str = Option( # noqa: B008 None, help="Path to a file containing coordinates for patch extraction. The file \ @@ -93,6 +98,7 @@ def extract_patches( coords=coords, out_dir=out_folder, idx_add=idx_add, + ds_token=ds_token, token=token, pad_value=pad_value, ) From b577c5ea41377eaa6fb9c0cb23dcb08e080d009f Mon Sep 17 00:00:00 2001 From: LorenzLamm Date: Wed, 3 Jan 2024 15:43:46 +0100 Subject: [PATCH 16/23] Add dataset token to filename --- .../annotations/extract_patches.py | 31 +++++++++++++------ 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/membrain_seg/annotations/extract_patches.py b/src/membrain_seg/annotations/extract_patches.py index d9628ab..7e5ff78 100644 --- a/src/membrain_seg/annotations/extract_patches.py +++ b/src/membrain_seg/annotations/extract_patches.py @@ -51,7 +51,7 @@ def pad_labels(patch, padding, pad_value=2.0): def get_out_files_and_patch_number( - token, out_folder_raw, out_folder_lab, patch_nr, idx_add + ds_token, token, out_folder_raw, out_folder_lab, patch_nr, idx_add ): """ Create filenames and corrected patch numbers. 
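[Aside on PATCHES 15/16] This patch and the following one prepend a dataset token to every extracted patch filename. Assuming ds_token="ds1" and token="tomo17" (illustrative values, as is the folder name), the scheme produces names like this:

    import os

    ds_token, token, patch_nr = "ds1", "tomo17", 0
    out_folder_raw = "patches/imagesTr"  # hypothetical output folder
    out_file_patch = os.path.join(
        out_folder_raw, ds_token + "_" + token + "_patch" + str(patch_nr) + ".nii.gz"
    )
    print(out_file_patch)  # patches/imagesTr/ds1_tomo17_patch0.nii.gz

Since the dataset token presumably has to be recovered from the filename later, the CLI help insists that it contain no underscores.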
From b577c5ea41377eaa6fb9c0cb23dcb08e080d009f Mon Sep 17 00:00:00 2001
From: LorenzLamm
Date: Wed, 3 Jan 2024 15:43:46 +0100
Subject: [PATCH 16/23] Add dataset token to filename

---
 .../annotations/extract_patches.py           | 31 +++++++++++++++++++++----------
 1 file changed, 21 insertions(+), 10 deletions(-)

diff --git a/src/membrain_seg/annotations/extract_patches.py b/src/membrain_seg/annotations/extract_patches.py
index d9628ab..7e5ff78 100644
--- a/src/membrain_seg/annotations/extract_patches.py
+++ b/src/membrain_seg/annotations/extract_patches.py
@@ -51,7 +51,7 @@ def pad_labels(patch, padding, pad_value=2.0):
 
 
 def get_out_files_and_patch_number(
-    token, out_folder_raw, out_folder_lab, patch_nr, idx_add
+    ds_token, token, out_folder_raw, out_folder_lab, patch_nr, idx_add
 ):
     """
     Create filenames and corrected patch numbers.
@@ -62,8 +62,10 @@ def get_out_files_and_patch_number(
 
     Parameters
    ----------
+    ds_token : str
+        The dataset identifier used as a part of the filename.
     token : str
-        The unique identifier used as a part of the filename.
+        The tomogram identifier used as a part of the filename.
     out_folder_raw : str
         The directory path where raw data patches are stored.
     out_folder_lab : str
@@ -96,27 +98,34 @@ def get_out_files_and_patch_number(
     """
     patch_nr += idx_add
     out_file_patch = os.path.join(
-        out_folder_raw, token + "_patch" + str(patch_nr) + "_raw.nii.gz"
+        out_folder_raw, ds_token + "_" + token + "_patch" + str(patch_nr) + ".nii.gz"
     )
     out_file_patch_label = os.path.join(
-        out_folder_lab, token + "_patch" + str(patch_nr) + "_labels.nii.gz"
+        out_folder_lab, ds_token + "_" + token + "_patch" + str(patch_nr) + ".nii.gz"
     )
     exist_add = 0
     while os.path.isfile(out_file_patch):
         exist_add += 1
         out_file_patch = os.path.join(
             out_folder_raw,
-            token + "_patch" + str(patch_nr + exist_add) + "_raw.nii.gz",
+            ds_token + "_" + token + "_patch" + str(patch_nr + exist_add) + ".nii.gz",
         )
         out_file_patch_label = os.path.join(
             out_folder_lab,
-            token + "_patch" + str(patch_nr + exist_add) + "_labels.nii.gz",
+            ds_token + "_" + token + "_patch" + str(patch_nr + exist_add) + ".nii.gz",
         )
     return patch_nr + exist_add, out_file_patch, out_file_patch_label
 
 
 def extract_patches(
-    tomo_path, seg_path, coords, out_dir, idx_add=0, token=None, pad_value=2.0
+    tomo_path,
+    seg_path,
+    coords,
+    out_dir,
+    ds_token="other",
+    token=None,
+    idx_add=0,
+    pad_value=2.0,
 ):
     """
     Extracts 3D patches from a given tomogram and corresponding segmentation.
@@ -133,11 +142,13 @@ def extract_patches(
         List of tuples where each tuple represents the 3D coordinates of a patch center.
     out_dir : str
         The output directory where the extracted patches will be saved.
-    idx_add : int, optional
-        The index addition for patch numbering, default is 0.
+    ds_token : str, optional
+        Dataset token to uniquely identify the dataset, default is 'other'.
     token : str, optional
         Token to uniquely identify the tomogram, default is None.
         If None, the base name of the tomogram file path is used.
+    idx_add : int, optional
+        The index addition for patch numbering, default is 0.
     pad_value: float, optional
         Borders of extracted patch are padded with this value ("ignore" label)
@@ -170,7 +181,7 @@ def extract_patches(
 
     for patch_nr, cur_coords in enumerate(coords):
         patch_nr, out_file_patch, out_file_patch_label = get_out_files_and_patch_number(
-            token, out_folder_raw, out_folder_lab, patch_nr, idx_add
+            ds_token, token, out_folder_raw, out_folder_lab, patch_nr, idx_add
         )
         print("Extracting patch nr", patch_nr, "from tomo", token)
         try:

From 53739771a7ea2c78b08c7b13b8e3b1bb698e8b6c Mon Sep 17 00:00:00 2001
From: LorenzLamm
Date: Wed, 3 Jan 2024 15:44:38 +0100
Subject: [PATCH 17/23] Update warnings

---
 src/membrain_seg/annotations/merge_corrections.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/membrain_seg/annotations/merge_corrections.py b/src/membrain_seg/annotations/merge_corrections.py
index 361c64a..e2c0d68 100644
--- a/src/membrain_seg/annotations/merge_corrections.py
+++ b/src/membrain_seg/annotations/merge_corrections.py
@@ -46,13 +46,15 @@ def get_corrections_from_folder(folder_name, orig_pred_file):
             or filename.startswith("Ignore")
             or filename.startswith("ignore")
         ):
-            print("ATTENTION! Not processing", filename)
-            print("Is this intended?")
+            print(
+                "File does not fit into Add/Remove/Ignore naming! " "Not processing",
+                filename,
+            )
             continue
         readdata = sitk.GetArrayFromImage(
             sitk.ReadImage(os.path.join(folder_name, filename))
         )
-        print("Adding file", filename, "<--")
+        print("Adding file", filename)
 
         if filename.startswith("Add") or filename.startswith("add"):
             add_patch += readdata

From b4c9e438041471586fc1461fe32c26533176f717 Mon Sep 17 00:00:00 2001
From: LorenzLamm
Date: Wed, 3 Jan 2024 15:45:43 +0100
Subject: [PATCH 18/23] Fix bug for accuracy masking

---
 src/membrain_seg/segmentation/training/metric_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/membrain_seg/segmentation/training/metric_utils.py b/src/membrain_seg/segmentation/training/metric_utils.py
index f30a258..ad8f300 100644
--- a/src/membrain_seg/segmentation/training/metric_utils.py
+++ b/src/membrain_seg/segmentation/training/metric_utils.py
@@ -34,7 +34,7 @@ def masked_accuracy(
     mask = (
         y_gt == ignore_label
         if ignore_label is not None
-        else torch.ones_like(y_gt).bool()
+        else torch.zeros_like(y_gt).bool()
     )
     acc = (threshold_function(y_pred, threshold_value=threshold_value) == y_gt).float()
     acc[mask] = 0.0

From 6d1b0e594f062ea8f1079dbbcbc5524cd0d3933a Mon Sep 17 00:00:00 2001
From: LorenzLamm
Date: Wed, 3 Jan 2024 16:44:25 +0100
Subject: [PATCH 19/23] Fix Dice reduction to scalar

---
 src/membrain_seg/segmentation/networks/unet.py        | 2 +-
 src/membrain_seg/segmentation/training/optim_utils.py | 4 ++++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/membrain_seg/segmentation/networks/unet.py b/src/membrain_seg/segmentation/networks/unet.py
index 1daac61..711f1d1 100644
--- a/src/membrain_seg/segmentation/networks/unet.py
+++ b/src/membrain_seg/segmentation/networks/unet.py
@@ -115,7 +115,7 @@ def __init__(
         losses = []
         weights = []
         loss_inclusion_tokens = []
-        ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="mean")
+        ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="none")
         losses.append(ignore_dice_loss)
         weights.append(1.0)
         loss_inclusion_tokens.append(["all"])  # Apply to every element
diff --git a/src/membrain_seg/segmentation/training/optim_utils.py b/src/membrain_seg/segmentation/training/optim_utils.py
index bd807ab..e2c9572 100644
--- a/src/membrain_seg/segmentation/training/optim_utils.py
+++ b/src/membrain_seg/segmentation/training/optim_utils.py
@@ -111,6 +111,10 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
         # Combine the Dice and Cross Entropy losses
         combined_loss = self.lambda_dice * dice_loss + self.lambda_ce * bce_loss
+        if self.reduction == "mean":
+            combined_loss = combined_loss.mean()
+        elif self.reduction == "sum":
+            combined_loss = combined_loss.sum()
 
         return combined_loss
From bc8acf004c4729c251aba7fbe3435a5c226296a7 Mon Sep 17 00:00:00 2001
From: LorenzLamm
Date: Wed, 3 Jan 2024 16:45:40 +0100
Subject: [PATCH 20/23] Make test compatible with CombinedLoss

---
 tests/membrain_seg/training/test_optim_utils.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/tests/membrain_seg/training/test_optim_utils.py b/tests/membrain_seg/training/test_optim_utils.py
index 4151d7a..dd0433c 100644
--- a/tests/membrain_seg/training/test_optim_utils.py
+++ b/tests/membrain_seg/training/test_optim_utils.py
@@ -12,6 +12,7 @@ def test_loss_fn_correctness():
     import torch
 
     from membrain_seg.segmentation.training.optim_utils import (
+        CombinedLoss,
         DeepSuperVisionLoss,
         IgnoreLabelDiceCELoss,
     )
@@ -73,17 +74,24 @@ def extend_labels(labels):
     pred_labels[2][pred_labels[2] < 0.0] = 0.0
 
     ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="mean")
+    combined_loss = CombinedLoss(
+        losses=[ignore_dice_loss], weights=[1.0], loss_inclusion_tokens=["ds1"]
+    )
     losses = test_ignore_dice_loss(ignore_dice_loss, pred_labels, gt_labels)
     assert losses[0] == losses[1] == losses[2] == losses[4] != losses[3]
 
     deep_supervision_loss = DeepSuperVisionLoss(
-        ignore_dice_loss, weights=[1.0, 0.5, 0.25, 0.125, 0.0675]
+        combined_loss, weights=[1.0, 0.5, 0.25, 0.125, 0.0675]
     )
     gt_labels_ds = extend_labels(gt_labels)
     ds_losses = []
     for pred_label in pred_labels:
         pred_labels_ds = extend_labels(pred_label)
-        ds_losses.append(deep_supervision_loss(pred_labels_ds, gt_labels_ds))
+        ds_losses.append(
+            deep_supervision_loss(
+                pred_labels_ds, gt_labels_ds, ["ds1"] * len(gt_labels_ds)
+            )
+        )
     assert ds_losses[0] == ds_losses[1] == ds_losses[2] == ds_losses[4] != ds_losses[3]
 
     print("All ignore loss assertions passed.")

From 9f39fc756bec3579f70bc8c5a2275beea5c3c190 Mon Sep 17 00:00:00 2001
From: LorenzLamm <34575029+LorenzLamm@users.noreply.github.com>
Date: Wed, 3 Jan 2024 17:04:04 +0100
Subject: [PATCH 21/23] Fix default path

---
 .../segmentation/training/training_param_summary.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/membrain_seg/segmentation/training/training_param_summary.py b/src/membrain_seg/segmentation/training/training_param_summary.py
index d964c90..67277c8 100644
--- a/src/membrain_seg/segmentation/training/training_param_summary.py
+++ b/src/membrain_seg/segmentation/training/training_param_summary.py
@@ -1,5 +1,5 @@
 def print_training_parameters(
-    data_dir: str = "/scicore/home/engel0006/GROUP/pool-engel/Lorenz/MemBrain-seg/data",
+    data_dir: str = "",
     log_dir: str = "logs/",
     batch_size: int = 2,
     num_workers: int = 8,

From 054761180b39ce03e97ad024f4ef248cb680af3e Mon Sep 17 00:00:00 2001
From: LorenzLamm
Date: Mon, 22 Jan 2024 18:31:06 +0100
Subject: [PATCH 22/23] Raise Error when reduction is not defined

---
 src/membrain_seg/segmentation/training/optim_utils.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/membrain_seg/segmentation/training/optim_utils.py b/src/membrain_seg/segmentation/training/optim_utils.py
index e2c9572..a16e136 100644
--- a/src/membrain_seg/segmentation/training/optim_utils.py
+++ b/src/membrain_seg/segmentation/training/optim_utils.py
@@ -115,6 +115,11 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             combined_loss = combined_loss.mean()
         elif self.reduction == "sum":
             combined_loss = combined_loss.sum()
+        else:
+            raise ValueError(
+                f"Invalid reduction type {self.reduction}. "
+                "Valid options are 'mean' and 'sum'."
+            )
 
         return combined_loss

From 346d850c50d7fb215cb24563fc472c467d31e7a2 Mon Sep 17 00:00:00 2001
From: LorenzLamm
Date: Mon, 22 Jan 2024 18:32:22 +0100
Subject: [PATCH 23/23] Add required dimensions to docstrings

---
 src/membrain_seg/segmentation/training/surface_dice.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/membrain_seg/segmentation/training/surface_dice.py b/src/membrain_seg/segmentation/training/surface_dice.py
index 300dd7f..bc2c6b4 100644
--- a/src/membrain_seg/segmentation/training/surface_dice.py
+++ b/src/membrain_seg/segmentation/training/surface_dice.py
@@ -271,7 +271,6 @@ def apply_gaussian_filter(
     it performs the operation separately for each channel of each batch item.
""" # Create the Gaussian kernel or load it from the dictionary - global gaussian_kernel_dict if (kernel_size, sigma) not in gaussian_kernel_dict.keys(): gaussian_kernel_dict[(kernel_size, sigma)] = gaussian_kernel( kernel_size, sigma @@ -354,8 +353,10 @@ def masked_surface_dice( ---------- data : torch.Tensor Tensor of model outputs representing the predicted segmentation. + Expected shape: (B, C, D, H, W) target : torch.Tensor Tensor of target labels representing the ground truth segmentation. + Expected shape: (B, 1, D, H, W) ignore_label : int The label value to be ignored when computing the loss. soft_skel_iterations : int @@ -435,8 +436,10 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ---------- data : torch.Tensor Tensor of model outputs. + Expected shape: (B, C, D, H, W) target : torch.Tensor Tensor of target labels. + Expected shape: (B, 1, D, H, W) Returns -------