diff --git a/src/membrain_seg/annotations/extract_patch_cli.py b/src/membrain_seg/annotations/extract_patch_cli.py index 339e476..a5074f2 100644 --- a/src/membrain_seg/annotations/extract_patch_cli.py +++ b/src/membrain_seg/annotations/extract_patch_cli.py @@ -21,6 +21,11 @@ def extract_patches( help="Path to the folder where extracted patches should be stored. \ (subdirectories will be created)", ), + ds_token: str = Option( # noqa: B008 + "other", + help="Dataset token. Important for distinguishing between different \ + datasets. Should NOT contain underscores!", + ), coords_file: str = Option( # noqa: B008 None, help="Path to a file containing coordinates for patch extraction. The file \ @@ -93,6 +98,7 @@ def extract_patches( coords=coords, out_dir=out_folder, idx_add=idx_add, + ds_token=ds_token, token=token, pad_value=pad_value, ) diff --git a/src/membrain_seg/annotations/extract_patches.py b/src/membrain_seg/annotations/extract_patches.py index d9628ab..7e5ff78 100644 --- a/src/membrain_seg/annotations/extract_patches.py +++ b/src/membrain_seg/annotations/extract_patches.py @@ -51,7 +51,7 @@ def pad_labels(patch, padding, pad_value=2.0): def get_out_files_and_patch_number( - token, out_folder_raw, out_folder_lab, patch_nr, idx_add + ds_token, token, out_folder_raw, out_folder_lab, patch_nr, idx_add ): """ Create filenames and corrected patch numbers. @@ -62,8 +62,10 @@ def get_out_files_and_patch_number( Parameters ---------- + ds_token : str + The dataset identifier used as a part of the filename. token : str - The unique identifier used as a part of the filename. + The tomogram identifier used as a part of the filename. out_folder_raw : str The directory path where raw data patches are stored. out_folder_lab : str @@ -96,27 +98,34 @@ def get_out_files_and_patch_number( """ patch_nr += idx_add out_file_patch = os.path.join( - out_folder_raw, token + "_patch" + str(patch_nr) + "_raw.nii.gz" + out_folder_raw, ds_token + "_" + token + "_patch" + str(patch_nr) + ".nii.gz" ) out_file_patch_label = os.path.join( - out_folder_lab, token + "_patch" + str(patch_nr) + "_labels.nii.gz" + out_folder_lab, ds_token + "_" + token + "_patch" + str(patch_nr) + ".nii.gz" ) exist_add = 0 while os.path.isfile(out_file_patch): exist_add += 1 out_file_patch = os.path.join( out_folder_raw, - token + "_patch" + str(patch_nr + exist_add) + "_raw.nii.gz", + ds_token + "_" + token + "_patch" + str(patch_nr + exist_add) + ".nii.gz", ) out_file_patch_label = os.path.join( out_folder_lab, - token + "_patch" + str(patch_nr + exist_add) + "_labels.nii.gz", + ds_token + "_" + token + "_patch" + str(patch_nr + exist_add) + ".nii.gz", ) return patch_nr + exist_add, out_file_patch, out_file_patch_label def extract_patches( - tomo_path, seg_path, coords, out_dir, idx_add=0, token=None, pad_value=2.0 + tomo_path, + seg_path, + coords, + out_dir, + ds_token="other", + token=None, + idx_add=0, + pad_value=2.0, ): """ Extracts 3D patches from a given tomogram and corresponding segmentation. @@ -133,11 +142,13 @@ def extract_patches( List of tuples where each tuple represents the 3D coordinates of a patch center. out_dir : str The output directory where the extracted patches will be saved. - idx_add : int, optional - The index addition for patch numbering, default is 0. + ds_token : str, optional + Dataset token to uniquely identify the dataset, default is 'other'. token : str, optional Token to uniquely identify the tomogram, default is None. If None, the base name of the tomogram file path is used. + idx_add : int, optional + The index addition for patch numbering, default is 0. pad_value: float, optional Borders of extracted patch are padded with this value ("ignore" label) @@ -170,7 +181,7 @@ def extract_patches( for patch_nr, cur_coords in enumerate(coords): patch_nr, out_file_patch, out_file_patch_label = get_out_files_and_patch_number( - token, out_folder_raw, out_folder_lab, patch_nr, idx_add + ds_token, token, out_folder_raw, out_folder_lab, patch_nr, idx_add ) print("Extracting patch nr", patch_nr, "from tomo", token) try: diff --git a/src/membrain_seg/annotations/merge_corrections.py b/src/membrain_seg/annotations/merge_corrections.py index 361c64a..e2c0d68 100644 --- a/src/membrain_seg/annotations/merge_corrections.py +++ b/src/membrain_seg/annotations/merge_corrections.py @@ -46,13 +46,15 @@ def get_corrections_from_folder(folder_name, orig_pred_file): or filename.startswith("Ignore") or filename.startswith("ignore") ): - print("ATTENTION! Not processing", filename) - print("Is this intended?") + print( + "File does not fit into Add/Remove/Ignore naming! " "Not processing", + filename, + ) continue readdata = sitk.GetArrayFromImage( sitk.ReadImage(os.path.join(folder_name, filename)) ) - print("Adding file", filename, "<--") + print("Adding file", filename) if filename.startswith("Add") or filename.startswith("add"): add_patch += readdata diff --git a/src/membrain_seg/segmentation/cli/train_cli.py b/src/membrain_seg/segmentation/cli/train_cli.py index ac3d394..8cc6132 100644 --- a/src/membrain_seg/segmentation/cli/train_cli.py +++ b/src/membrain_seg/segmentation/cli/train_cli.py @@ -1,4 +1,7 @@ +from typing import List, Optional + from typer import Option +from typing_extensions import Annotated from ..train import train as _train from .cli import OPTION_PROMPT_KWARGS as PKWARGS @@ -70,7 +73,7 @@ def train_advanced( help="Batch size for training.", ), num_workers: int = Option( # noqa: B008 - 1, + 8, help="Number of worker threads for loading data", ), max_epochs: int = Option( # noqa: B008 @@ -84,6 +87,22 @@ def train_advanced( but also severely increases training time.\ Pass "True" or "False".', ), + use_surface_dice: bool = Option( # noqa: B008 + False, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".' + ), + surface_dice_weight: float = Option( # noqa: B008 + 1.0, help="Scaling factor for the Surface-Dice loss. " + ), + surface_dice_tokens: Annotated[ + Optional[List[str]], + Option( + help='List of tokens to \ + use for the Surface-Dice loss. \ + Pass tokens separately:\ + For example, train_advanced --surface_dice_tokens "ds1" \ + --surface_dice_tokens "ds2"' + ), + ] = None, use_deep_supervision: bool = Option( # noqa: B008 True, help='Whether to use deep supervision. Pass "True" or "False".' ), @@ -119,6 +138,12 @@ def train_advanced( If set to False, data augmentation still happens, but not as frequently. More data augmentation can lead to a better performance, but also increases the training time substantially. + use_surface_dice : bool + Determines whether to use Surface-Dice loss, by default True. + surface_dice_weight : float + Scaling factor for the Surface-Dice loss, by default 1.0. + surface_dice_tokens : list + List of tokens to use for the Surface-Dice loss, by default ["all"]. use_deep_supervision : bool Determines whether to use deep supervision, by default True. project_name : str @@ -140,6 +165,9 @@ def train_advanced( max_epochs=max_epochs, aug_prob_to_one=aug_prob_to_one, use_deep_supervision=use_deep_supervision, + use_surf_dice=use_surface_dice, + surf_dice_weight=surface_dice_weight, + surf_dice_tokens=surface_dice_tokens, project_name=project_name, sub_name=sub_name, ) diff --git a/src/membrain_seg/segmentation/dataloading/memseg_dataset.py b/src/membrain_seg/segmentation/dataloading/memseg_dataset.py index d856883..9b66dda 100644 --- a/src/membrain_seg/segmentation/dataloading/memseg_dataset.py +++ b/src/membrain_seg/segmentation/dataloading/memseg_dataset.py @@ -1,7 +1,6 @@ import os from typing import Dict -# from skimage import io import imageio as io import numpy as np from torch.utils.data import Dataset @@ -102,6 +101,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: "label": np.expand_dims(self.labels[idx], 0), } idx_dict = self.transforms(idx_dict) + idx_dict["dataset"] = self.dataset_labels[idx] return idx_dict def __len__(self) -> int: @@ -126,6 +126,7 @@ def load_data(self) -> None: print("Loading images into dataset.") self.imgs = [] self.labels = [] + self.dataset_labels = [] for entry in self.data_paths: label = read_nifti( entry[1] @@ -137,6 +138,7 @@ def load_data(self) -> None: img = np.transpose(img, (1, 2, 0)) self.imgs.append(img) self.labels.append(label) + self.dataset_labels.append(get_dataset_token(entry[0])) def initialize_imgs_paths(self) -> None: """ @@ -190,3 +192,23 @@ def test(self, test_folder: str, num_files: int = 20) -> None: os.path.join(test_folder, f"test_mask_ds2_{i}_group{num_mask}.png"), test_sample["label"][1][0, :, :, num_mask], ) + + +def get_dataset_token(patch_name): + """ + Get the dataset token from the patch name. + + Parameters + ---------- + patch_name : str + The name of the patch. + + Returns + ------- + str + The dataset token. + + """ + basename = os.path.basename(patch_name) + dataset_token = basename.split("_")[0] + return dataset_token diff --git a/src/membrain_seg/segmentation/networks/unet.py b/src/membrain_seg/segmentation/networks/unet.py index 7a67310..711f1d1 100644 --- a/src/membrain_seg/segmentation/networks/unet.py +++ b/src/membrain_seg/segmentation/networks/unet.py @@ -8,17 +8,13 @@ from monai.transforms import AsDiscrete, Compose, EnsureType, Lambda from ..training.metric_utils import masked_accuracy, threshold_function - -# from monai.networks.nets import UNet as MonaiUnet -# The normal Monai DynUNet upsamples low-resolution layers to compare directly to GT -# My implementation leaves them in low resolution and compares to down-sampled GT -# Not sure which implementation is better -# To be discussed with Alister & Kevin from ..training.optim_utils import ( + CombinedLoss, DeepSuperVisionLoss, - DynUNetDirectDeepSupervision, # I like to use deep supervision + DynUNetDirectDeepSupervision, IgnoreLabelDiceCELoss, ) +from ..training.surface_dice import IgnoreLabelSurfaceDiceLoss, masked_surface_dice class SemanticSegmentationUnet(pl.LightningModule): @@ -62,6 +58,12 @@ class SemanticSegmentationUnet(pl.LightningModule): The maximum number of epochs for training. use_deep_supervision : bool, default=False Whether to use deep supervision. + use_surf_dice : bool, default=False + Whether to use surface dice loss. + surf_dice_weight : float, default=1.0 + The weight for the surface dice loss. + surf_dice_tokens : list, default=[] + The tokens for which to compute the surface dice loss. """ @@ -80,6 +82,9 @@ def __init__( roi_size: Tuple[int, ...] = (160, 160, 160), max_epochs: int = 1000, use_deep_supervision: bool = False, + use_surf_dice: bool = False, + surf_dice_weight: float = 1.0, + surf_dice_tokens: list = None, ): super().__init__() @@ -102,16 +107,39 @@ def __init__( upsample_kernel_size=(1, 2, 2, 2, 2, 2), filters=channels, res_block=True, - # norm_name="INSTANCE", - # norm=Norm.INSTANCE, # I like the instance normalization better than - # batchnorm in this case, as we will probably have - # only small batch sizes, making BN more noisy deep_supervision=True, deep_supr_num=2, ) - ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="mean") + + ### Build up loss function + losses = [] + weights = [] + loss_inclusion_tokens = [] + ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="none") + losses.append(ignore_dice_loss) + weights.append(1.0) + loss_inclusion_tokens.append(["all"]) # Apply to every element + + if use_surf_dice: + if surf_dice_tokens is None: + surf_dice_tokens = ["all"] + ignore_surf_dice_loss = IgnoreLabelSurfaceDiceLoss( + ignore_label=2, soft_skel_iterations=5 + ) + losses.append(ignore_surf_dice_loss) + weights.append(surf_dice_weight) + loss_inclusion_tokens.append(surf_dice_tokens) + + scaled_weights = [entry / sum(weights) for entry in weights] + + loss_function = CombinedLoss( + losses=losses, + weights=scaled_weights, + loss_inclusion_tokens=loss_inclusion_tokens, + ) + self.loss_function = DeepSuperVisionLoss( - ignore_dice_loss, + loss_function, weights=[1.0, 0.5, 0.25, 0.125, 0.0675] if use_deep_supervision else [1.0, 0.0, 0.0, 0.0, 0.0], @@ -143,7 +171,9 @@ def __init__( self.training_step_outputs = [] self.validation_step_outputs = [] self.running_train_acc = 0.0 + self.running_train_surf_dice = 0.0 self.running_val_acc = 0.0 + self.running_val_surf_dice = 0.0 def forward(self, x) -> torch.Tensor: """Implementation of the forward pass. @@ -180,9 +210,9 @@ def training_step( See the pytorch-lightning module documentation for details. """ - images, labels = batch["image"], batch["label"] + images, labels, ds_label = batch["image"], batch["label"], batch["dataset"] output = self.forward(images) - loss = self.loss_function(output, labels) + loss = self.loss_function(output, labels, ds_label) stats_dict = {"train_loss": loss, "train_number": output[0].shape[0]} self.training_step_outputs.append(stats_dict) @@ -190,6 +220,17 @@ def training_step( masked_accuracy(output[0], labels[0], ignore_label=2.0, threshold_value=0.0) * output[0].shape[0] ) + self.running_train_surf_dice += ( + masked_surface_dice( + data=output[0].detach(), + target=labels[0].detach(), + ignore_label=2.0, + soft_skel_iterations=5, + smooth=1.0, + reduction="mean", + ) + * output[0].shape[0] + ) return {"loss": loss} @@ -207,13 +248,17 @@ def on_train_epoch_end(self): mean_train_loss = torch.tensor(train_loss / num_items) mean_train_acc = self.running_train_acc / num_items + mean_train_surf_dice = self.running_train_surf_dice / num_items self.running_train_acc = 0.0 - self.log("train_loss", mean_train_loss) # , batch_size=num_items) - self.log("train_acc", mean_train_acc) # , batch_size=num_items) + self.running_train_surf_dice = 0.0 + self.log("train_loss", mean_train_loss) + self.log("train_acc", mean_train_acc) + self.log("train_surf_dice", mean_train_surf_dice) self.training_step_outputs = [] print("EPOCH Training loss", mean_train_loss.item()) print("EPOCH Training acc", mean_train_acc.item()) + print("EPOCH Training surface dice", mean_train_surf_dice.item()) # Accuracy not the most informative metric, but a good sanity check return {"train_loss": mean_train_loss} @@ -224,13 +269,9 @@ def validation_step(self, batch, batch_idx): using a sliding window. See the pytorch-lightning module documentation for details. """ - images, labels = batch[self.image_key], batch[self.label_key] - # sw_batch_size = 4 - # outputs = sliding_window_inference( - # images, self.roi_size, sw_batch_size, self.forward - # ) + images, labels, ds_label = batch["image"], batch["label"], batch["dataset"] outputs = self.forward(images) - loss = self.loss_function(outputs, labels) + loss = self.loss_function(outputs, labels, ds_label) # Cloning and adjusting preds & labels for Dice. # Could also use the same labels, but maybe we want to @@ -254,6 +295,18 @@ def validation_step(self, batch, batch_idx): ) * outputs[0].shape[0] ) + + self.running_val_surf_dice += ( + masked_surface_dice( + data=outputs[0].detach(), + target=labels[0].detach(), + ignore_label=2.0, + soft_skel_iterations=5, + smooth=1.0, + reduction="mean", + ) + * outputs[0].shape[0] + ) return stats_dict def on_validation_epoch_end(self): @@ -270,13 +323,17 @@ def on_validation_epoch_end(self): mean_val_loss = torch.tensor(val_loss / num_items) mean_val_acc = self.running_val_acc / num_items + mean_val_surf_dice = self.running_val_surf_dice / num_items self.running_val_acc = 0.0 - self.log("val_loss", mean_val_loss), # batch_size=num_items) - self.log("val_dice", mean_val_dice) # , batch_size=num_items) + self.running_val_surf_dice = 0.0 + self.log("val_loss", mean_val_loss), + self.log("val_dice", mean_val_dice) + self.log("val_surf_dice", mean_val_surf_dice) self.log("val_accuracy", mean_val_acc) self.validation_step_outputs = [] print("EPOCH Validation loss", mean_val_loss.item()) print("EPOCH Validation dice", mean_val_dice) + print("EPOCH Validation surface dice", mean_val_surf_dice.item()) print("EPOCH Validation acc", mean_val_acc.item()) return {"val_loss": mean_val_loss, "val_metric": mean_val_dice} diff --git a/src/membrain_seg/segmentation/train.py b/src/membrain_seg/segmentation/train.py index 2f077f2..9c576f5 100644 --- a/src/membrain_seg/segmentation/train.py +++ b/src/membrain_seg/segmentation/train.py @@ -9,6 +9,9 @@ MemBrainSegDataModule, ) from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet +from membrain_seg.segmentation.training.training_param_summary import ( + print_training_parameters, +) warnings.filterwarnings("ignore", category=UserWarning, module="torch._tensor") warnings.filterwarnings("ignore", category=UserWarning, module="monai.data") @@ -24,6 +27,9 @@ def train( use_deep_supervision: bool = False, project_name: str = "membrain-seg_v0", sub_name: str = "1", + use_surf_dice: bool = False, + surf_dice_weight: float = 1.0, + surf_dice_tokens: list = None, ): """ Train the model on the specified data. @@ -52,11 +58,31 @@ def train( Name of the project for logging purposes. sub_name : str, optional Sub-name of the project for logging purposes. + use_surf_dice : bool, optional + If True, enables Surface-Dice loss. + surf_dice_weight : float, optional + Weight for the Surface-Dice loss. + surf_dice_tokens : list, optional + List of tokens to use for the Surface-Dice loss. Returns ------- None """ + print_training_parameters( + data_dir=data_dir, + log_dir=log_dir, + batch_size=batch_size, + num_workers=num_workers, + max_epochs=max_epochs, + aug_prob_to_one=aug_prob_to_one, + use_deep_supervision=use_deep_supervision, + project_name=project_name, + sub_name=sub_name, + use_surf_dice=use_surf_dice, + surf_dice_weight=surf_dice_weight, + surf_dice_tokens=surf_dice_tokens, + ) # Set up the data module data_module = MemBrainSegDataModule( data_dir=data_dir, @@ -67,15 +93,16 @@ def train( # Set up the model model = SemanticSegmentationUnet( - max_epochs=max_epochs, use_deep_supervision=use_deep_supervision + max_epochs=max_epochs, + use_deep_supervision=use_deep_supervision, + use_surf_dice=use_surf_dice, + surf_dice_weight=surf_dice_weight, + surf_dice_tokens=surf_dice_tokens, ) project_name = project_name checkpointing_name = project_name + "_" + sub_name # Set up logging - wandb_logger = pl_loggers.WandbLogger( - project=project_name, log_model=False, save_code=True - ) csv_logger = pl_loggers.CSVLogger(log_dir) # Set up model checkpointing @@ -106,7 +133,7 @@ def on_epoch_start(self, trainer, pl_module): # Set up the trainer trainer = pl.Trainer( precision="16-mixed", - logger=[csv_logger, wandb_logger], + logger=[csv_logger], callbacks=[ checkpoint_callback_val_loss, checkpoint_callback_regular, diff --git a/src/membrain_seg/segmentation/training/metric_utils.py b/src/membrain_seg/segmentation/training/metric_utils.py index f30a258..ad8f300 100644 --- a/src/membrain_seg/segmentation/training/metric_utils.py +++ b/src/membrain_seg/segmentation/training/metric_utils.py @@ -34,7 +34,7 @@ def masked_accuracy( mask = ( y_gt == ignore_label if ignore_label is not None - else torch.ones_like(y_gt).bool() + else torch.zeros_like(y_gt).bool() ) acc = (threshold_function(y_pred, threshold_value=threshold_value) == y_gt).float() acc[mask] = 0.0 diff --git a/src/membrain_seg/segmentation/training/optim_utils.py b/src/membrain_seg/segmentation/training/optim_utils.py index 8dbdc6b..a16e136 100644 --- a/src/membrain_seg/segmentation/training/optim_utils.py +++ b/src/membrain_seg/segmentation/training/optim_utils.py @@ -53,7 +53,7 @@ class IgnoreLabelDiceCELoss(_Loss): def __init__( self, ignore_label: int, - reduction: str = "mean", + reduction: str = "none", lambda_dice: float = 1.0, lambda_ce: float = 1.0, **kwargs, @@ -95,11 +95,31 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor: orig_data, target_tensor, reduction="none" ) bce_loss[~mask] = 0.0 - bce_loss = torch.sum(bce_loss) / torch.sum(mask) - dice_loss = self.dice_loss(data, target, mask) + # TODO: Check if this is correct: I adjusted the loss to be + # computed per batch element + bce_loss = torch.sum(bce_loss, dim=(1, 2, 3, 4)) / torch.sum( + mask, dim=(1, 2, 3, 4) + ) + # Compute Dice loss separately for each batch element + dice_loss = torch.zeros_like(bce_loss) + for batch_idx in range(data.shape[0]): + dice_loss[batch_idx] = self.dice_loss( + data[batch_idx].unsqueeze(0), + target[batch_idx].unsqueeze(0), + mask[batch_idx].unsqueeze(0), + ) # Combine the Dice and Cross Entropy losses combined_loss = self.lambda_dice * dice_loss + self.lambda_ce * bce_loss + if self.reduction == "mean": + combined_loss = combined_loss.mean() + elif self.reduction == "sum": + combined_loss = combined_loss.sum() + else: + raise ValueError( + f"Invalid reduction type {self.reduction}. " + "Valid options are 'mean' and 'sum'." + ) return combined_loss @@ -134,7 +154,7 @@ def __init__( self.loss_fn = loss_fn self.weights = weights - def forward(self, inputs: list, targets: list) -> torch.Tensor: + def forward(self, inputs: list, targets: list, ds_labels: list) -> torch.Tensor: """ Compute the loss. @@ -144,6 +164,8 @@ def forward(self, inputs: list, targets: list) -> torch.Tensor: List of tensors of model outputs. targets : list List of tensors of target labels. + ds_labels : list + List of dataset labels for each batch element. Returns ------- @@ -151,6 +173,96 @@ def forward(self, inputs: list, targets: list) -> torch.Tensor: The calculated loss. """ loss = 0.0 - for weight, data, target in zip(self.weights, inputs, targets): - loss += weight * self.loss_fn(data, target) + ds_labels_loop = [ds_labels] * 5 + for weight, data, target, ds_label in zip( + self.weights, inputs, targets, ds_labels_loop + ): + loss += weight * self.loss_fn(data, target, ds_label) + return loss + + +class CombinedLoss(_Loss): + """ + Combine multiple loss functions into a single one. + + Parameters + ---------- + losses : List[Callable] + A list of loss function instances. + weights : List[float] + List of weights corresponding to each loss function (must + be of same length as losses). + loss_inclusion_tokens : List[List[str]] + A list of lists containing tokens for each loss function. + Each sublist corresponds to a loss function and contains + tokens for which the loss should be included. + If the list contains "all", then the loss will be included + for all cases. + + Notes + ----- + IMPORTANT: Loss functions need to return a tensors containing the + loss for each batch element. + + The loss_exclusion_tokens parameter is used to exclude certain + cases from the loss calculation. For example, if the loss_exclusion_tokens + parameter is [["ds1", "ds2"], ["ds1"]], then the first loss function + will be excluded for cases where the dataset label is "ds1" or "ds2", + and the second loss function will be excluded for cases where the + dataset label is "ds1". + """ + + def __init__( + self, + losses: list, + weights: list, + loss_inclusion_tokens: list, + **kwargs, + ) -> None: + super().__init__() + self.losses = losses + self.weights = weights + self.loss_inclusion_tokens = loss_inclusion_tokens + + def forward( + self, data: torch.Tensor, target: torch.Tensor, ds_label: list + ) -> torch.Tensor: + """ + Compute the combined loss. + + Parameters + ---------- + data : torch.Tensor + Tensor of model outputs. + target : torch.Tensor + Tensor of target labels. + ds_label : List[str] + List of dataset labels for each batch element. + + Returns + ------- + torch.Tensor + The calculated combined loss. + """ + loss = 0.0 + for loss_idx, (cur_loss, cur_weight) in enumerate( + zip(self.losses, self.weights) + ): + cur_loss_val = cur_loss(data, target) + + # Zero out losses for excluded cases + for batch_idx, ds_lab in enumerate(ds_label): + if ( + "all" in self.loss_inclusion_tokens[loss_idx] + or ds_lab in self.loss_inclusion_tokens[loss_idx] + ): + continue + cur_loss_val[batch_idx] = 0.0 + + # Aggregate loss + cur_loss_val = cur_loss_val.sum() / ((cur_loss_val != 0.0).sum() + 1e-3) + loss += cur_weight * cur_loss_val + + # Normalize loss + loss = loss / sum(self.weights) return loss diff --git a/src/membrain_seg/segmentation/training/surface_dice.py b/src/membrain_seg/segmentation/training/surface_dice.py new file mode 100644 index 0000000..bc2c6b4 --- /dev/null +++ b/src/membrain_seg/segmentation/training/surface_dice.py @@ -0,0 +1,458 @@ +""" +Surface Dice implementation. + +Adapted from: clDice - A Novel Topology-Preserving Loss Function for Tubular +Structure Segmentation +Original Authors: Johannes C. Paetzold and Suprosanna Shit +Sources: https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/ + soft_skeleton.py + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py +License: MIT License. + +The following code is a modification of the original clDice implementation. +Modifications were made to include additional functionality and integrate +with new project requirements. The original license and copyright notice are +provided below. + +MIT License + +Copyright (c) 2021 Johannes C. Paetzold and Suprosanna Shit + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import math + +import torch +import torch.nn.functional as F +from torch.nn.functional import sigmoid +from torch.nn.modules.loss import _Loss + + +def soft_erode(img: torch.Tensor, separate_pool: bool = False) -> torch.Tensor: + """ + Apply soft erosion operation to the input image. + + Soft erosion is achieved by applying a min-pooling operation to the input image. + + Parameters + ---------- + img : torch.Tensor + Input image tensor with shape (B, C, D, H, W) + separate_pool : bool, optional + If True, perform separate 3D max-pooling operations along different axes. + Default is False. + + Returns + ------- + torch.Tensor + Eroded image tensor with the same shape as the input. + + Raises + ------ + ValueError + If the input tensor has an unsupported number of dimensions. + + Notes + ----- + - The soft erosion can be performed with separate 3D min-pooling operations + along different axes if separate_pool is True, or with a single + 3D min-pooling operation with a kernel of size (3, 3, 3) if + separate_pool is False. + """ + assert len(img.shape) == 5 + if separate_pool: + p1 = -F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0)) + p2 = -F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0)) + p3 = -F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1)) + return torch.min(torch.min(p1, p2), p3) + p4 = -F.max_pool3d(-img, (3, 3, 3), (1, 1, 1), (1, 1, 1)) + return p4 + + +def soft_dilate(img: torch.Tensor) -> torch.Tensor: + """ + Apply soft dilation operation to the input image. + + Soft dilation is achieved by applying a max-pooling operation to the input image. + + Parameters + ---------- + img : torch.Tensor + Input image tensor with shape (B, C, D, H, W). + + Returns + ------- + torch.Tensor + Dilated image tensor with the same shape as the input. + + Raises + ------ + ValueError + If the input tensor has an unsupported number of dimensions. + + Notes + ----- + - For 5D input, the soft dilation is performed using a 3D max-pooling operation + with a kernel of size (3, 3, 3). + """ + assert len(img.shape) == 5 + return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), (1, 1, 1)) + + +def soft_open(img: torch.Tensor, separate_pool: bool = False) -> torch.Tensor: + """ + Apply soft opening operation to the input image. + + Soft opening is achieved by applying soft erosion followed by soft dilation. + The intention of soft opening is to remove thin membranes from the segmentation. + + Parameters + ---------- + img : torch.Tensor + Input image tensor with shape (B, C, D, H, W). + separate_pool : bool, optional + If True, perform separate erosion and dilation operations. Default is False. + + Returns + ------- + torch.Tensor + Opened image tensor with the same shape as the input. + + Notes + ----- + - Soft opening is performed by applying soft erosion followed by soft dilation + to the input image. + - For 5D input, separate erosion and dilation can be performed if separate_pool + is True. + """ + return soft_dilate(soft_erode(img, separate_pool=separate_pool)) + + +def soft_skel( + img: torch.Tensor, iter_: int, separate_pool: bool = False +) -> torch.Tensor: + """ + Compute the soft skeleton of the input image. + + The skeleton is computed by applying soft erosion iteratively to the input image. + In each iteration, the difference between the input image and the "opened" image is + computed and added to the skeleton. + + Reasoning: if there is a difference between the input image and the "opened" image, + there must be a thin membrane skeleton in the input image that was removed by the + opening operation. + + Parameters + ---------- + img : torch.Tensor + Input image tensor with shape (B, C, D, H, W). + iter_ : int + Number of iterations for skeletonization. + separate_pool : bool, optional + If True, perform separate erosion and dilation operations. + Default is False. + + Returns + ------- + torch.Tensor + Soft skeleton image tensor with the same shape as the input. + + Notes + ----- + - Separate erosion can be performed if separate_pool is True. + """ + img1 = soft_open(img, separate_pool=separate_pool) + skel = F.relu(img - img1) + for _j in range(iter_): + img = soft_erode(img) + img1 = soft_open(img, separate_pool=separate_pool) + delta = F.relu(img - img1) + skel = skel + F.relu(delta - skel * delta) + return skel + + +def gaussian_kernel(size: int, sigma: float) -> torch.Tensor: + """ + Creates a 3D Gaussian kernel using the specified size and sigma. + + Parameters + ---------- + size : int + The size of the Gaussian kernel. It determines the length of + each dimension of the cube. + sigma : float + The standard deviation of the Gaussian kernel. It controls + the spread of the Gaussian. + + Returns + ------- + torch.Tensor + A 3D tensor representing the Gaussian kernel. + + Notes + ----- + The function creates a Gaussian kernel, which is essentially a + cube of dimensions [size, size, size]. Each entry in the cube is + computed using the Gaussian function based on its distance from the center. + The kernel is normalized so that its total sum equals 1. + """ + # Define a coordinate grid centered at (0,0,0) + grid = torch.arange(size, dtype=torch.float32) - (size - 1) / 2 + # Create a 3D meshgrid + x, y, z = torch.meshgrid(grid, grid, grid) + xyz_grid = torch.stack([x, y, z], dim=-1) + + # Calculate the 3D Gaussian kernel + gaussian_kernel = torch.exp(-torch.sum(xyz_grid**2, dim=-1) / (2 * sigma**2)) + gaussian_kernel /= (2 * math.pi * sigma**2) ** (3 / 2) # Normalize + + # Ensure sum of values in gaussian kernel equals 1. + gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) + return gaussian_kernel + + +gaussian_kernel_dict = {} +""" Not sure why, but moving the gaussian kernel to GPU takes surprisingly long.+ +So we precompute it, store it on GPU, and reuse it. +""" + + +def apply_gaussian_filter( + seg: torch.Tensor, kernel_size: int, sigma: float +) -> torch.Tensor: + """ + Apply a Gaussian filter to a segmentation tensor using PyTorch. + + This function convolves the input tensor with a Gaussian kernel. + The function creates or retrieves a Gaussian kernel based on the + specified size and standard deviation, and applies 3D convolution to each + channel of each batch item with appropriate padding to maintain spatial + dimensions. + + Parameters + ---------- + seg : torch.Tensor + The input segmentation tensor of shape (batch, channel, X, Y, Z). + kernel_size : int + The size of the Gaussian kernel, determining the length of each + dimension of the cube. + sigma : float + The standard deviation of the Gaussian kernel, controlling the spread. + + Returns + ------- + torch.Tensor + The filtered segmentation tensor of the same shape as input. + + Notes + ----- + This function uses a precomputed dictionary to enhance performance by + storing Gaussian kernels. If a kernel with the specified size and standard + deviation does not exist in the dictionary, it is created and added. The + function assumes the input tensor is a 5D tensor, applies 3D convolution + using the Gaussian kernel with padding to maintain spatial dimensions, and + it performs the operation separately for each channel of each batch item. + """ + # Create the Gaussian kernel or load it from the dictionary + if (kernel_size, sigma) not in gaussian_kernel_dict.keys(): + gaussian_kernel_dict[(kernel_size, sigma)] = gaussian_kernel( + kernel_size, sigma + ).to(seg.device) + g_kernel = gaussian_kernel_dict[(kernel_size, sigma)] + + # Add batch and channel dimensions + g_kernel = g_kernel.view(1, 1, *g_kernel.size()) + # Apply the Gaussian filter to each channel + padding = kernel_size // 2 + + # Move the kernel to the same device as the segmentation tensor + g_kernel = g_kernel.to(seg.device) + + # Apply the Gaussian filter + filtered_seg = F.conv3d(seg, g_kernel, padding=padding, groups=seg.shape[1]) + return filtered_seg + + +def get_GT_skeleton(gt_seg: torch.Tensor, iterations: int = 5) -> torch.Tensor: + """ + Generate the skeleton of a ground truth segmentation. + + This function takes a ground truth segmentation `gt_seg`, smooths it using a + Gaussian filter, and then computes its soft skeleton using the `soft_skel` function. + + Intention: When using the binary ground truth segmentation for skeletonization, + the resulting skeleton is very patchy and not smooth. When using the smoothed + ground truth segmentation, the resulting skeleton is much smoother and more + accurate. + + Parameters + ---------- + gt_seg : torch.Tensor + A torch.Tensor representing the ground truth segmentation. + Shape: (B, C, D, H, W) + iterations : int, optional + The number of iterations for skeletonization. Default is 5. + + Returns + ------- + torch.Tensor + A torch.Tensor representing the skeleton of the ground truth segmentation. + + Notes + ----- + - The input `gt_seg` should be a binary segmentation tensor where 1 represents the + object of interest. + - The function first smooths the `gt_seg` using a Gaussian filter to enhance the + object's structure. + - The skeletonization process is performed using the `soft_skel` function with the + specified number of iterations. + - The resulting skeleton is returned as a binary torch.Tensor where 1 indicates the + skeleton points. + """ + gt_smooth = ( + apply_gaussian_filter((gt_seg == 1) * 1.0, kernel_size=15, sigma=2.0) * 1.5 + ) + skel_gt = soft_skel(gt_smooth, iter_=iterations) + return skel_gt + + +def masked_surface_dice( + data: torch.Tensor, + target: torch.Tensor, + ignore_label: int = 2, + soft_skel_iterations: int = 3, + smooth: float = 3.0, + binary_prediction: bool = False, + reduction: str = "none", +) -> torch.Tensor: + """ + Compute the surface Dice loss with masking for ignore labels. + + The surface Dice loss measures the similarity between the predicted segmentation's + skeleton and the ground truth segmentation (and vice versa). Labels annotated with + "ignore_label" are ignored. + + Parameters + ---------- + data : torch.Tensor + Tensor of model outputs representing the predicted segmentation. + Expected shape: (B, C, D, H, W) + target : torch.Tensor + Tensor of target labels representing the ground truth segmentation. + Expected shape: (B, 1, D, H, W) + ignore_label : int + The label value to be ignored when computing the loss. + soft_skel_iterations : int + Number of iterations for skeletonization in the underlying operations. + smooth : float + Smoothing factor to avoid division by zero. + binary_prediction : bool + If True, the predicted segmentation is assumed to be binary. Default is False. + reduction : str + Specifies the reduction to apply to the output. Default is "none". + + Returns + ------- + torch.Tensor + The calculated surface Dice loss. + """ + # Create a mask to ignore the specified label in the target + data = sigmoid(data) + mask = target != ignore_label + + # Compute soft skeletonization + if binary_prediction: + skel_pred = get_GT_skeleton(data.clone(), soft_skel_iterations) + else: + skel_pred = soft_skel(data.clone(), soft_skel_iterations, separate_pool=False) + skel_true = get_GT_skeleton(target.clone(), soft_skel_iterations) + + # Mask out ignore labels + skel_pred[~mask] = 0 + skel_true[~mask] = 0 + + # compute surface dice loss + tprec = ( + torch.sum(torch.multiply(skel_pred, target), dim=(1, 2, 3, 4)) + smooth + ) / (torch.sum(skel_pred, dim=(1, 2, 3, 4)) + smooth) + tsens = (torch.sum(torch.multiply(skel_true, data), dim=(1, 2, 3, 4)) + smooth) / ( + torch.sum(skel_true, dim=(1, 2, 3, 4)) + smooth + ) + surf_dice_loss = 2.0 * (tprec * tsens) / (tprec + tsens) + if reduction == "none": + return surf_dice_loss + elif reduction == "mean": + return torch.mean(surf_dice_loss) + + +class IgnoreLabelSurfaceDiceLoss(_Loss): + """ + Surface Dice loss, adding ignore labels. + + Parameters + ---------- + ignore_label : int + The label to ignore when calculating the loss. + reduction : str, optional + Specifies the reduction to apply to the output, by default "mean". + kwargs : dict + Additional keyword arguments. + """ + + def __init__( + self, + ignore_label: int, + soft_skel_iterations: int = 3, + smooth: float = 3.0, + **kwargs, + ) -> None: + super().__init__() + self.ignore_label = ignore_label + self.soft_skel_iterations = soft_skel_iterations + self.smooth = smooth + + def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute the loss. + + Parameters + ---------- + data : torch.Tensor + Tensor of model outputs. + Expected shape: (B, C, D, H, W) + target : torch.Tensor + Tensor of target labels. + Expected shape: (B, 1, D, H, W) + + Returns + ------- + torch.Tensor + The calculated loss. + """ + # Create a mask to ignore the specified label in the target + surf_dice_score = masked_surface_dice( + data=data, + target=target, + ignore_label=self.ignore_label, + soft_skel_iterations=self.soft_skel_iterations, + smooth=self.smooth, + ) + surf_dice_loss = 1.0 - surf_dice_score + return surf_dice_loss diff --git a/src/membrain_seg/segmentation/training/training_param_summary.py b/src/membrain_seg/segmentation/training/training_param_summary.py new file mode 100644 index 0000000..67277c8 --- /dev/null +++ b/src/membrain_seg/segmentation/training/training_param_summary.py @@ -0,0 +1,117 @@ +def print_training_parameters( + data_dir: str = "", + log_dir: str = "logs/", + batch_size: int = 2, + num_workers: int = 8, + max_epochs: int = 1000, + aug_prob_to_one: bool = False, + use_deep_supervision: bool = False, + project_name: str = "membrain-seg_v0", + sub_name: str = "1", + use_surf_dice: bool = False, + surf_dice_weight: float = 1.0, + surf_dice_tokens: list = None, +): + """ + Print a formatted overview of the training parameters with explanations. + + Parameters + ---------- + data_dir : str, optional + Path to the directory containing training data. + log_dir : str, optional + Path to the directory where logs should be stored. + batch_size : int, optional + Number of samples per batch of input data. + num_workers : int, optional + Number of subprocesses to use for data loading. + max_epochs : int, optional + Maximum number of epochs to train for. + aug_prob_to_one : bool, optional + If True, all augmentation probabilities are set to 1. + use_deep_supervision : bool, optional + If True, enables deep supervision in the U-Net model. + project_name : str, optional + Name of the project for logging purposes. + sub_name : str, optional + Sub-name of the project for logging purposes. + use_surf_dice : bool, optional + If True, enables Surface-Dice loss. + surf_dice_weight : float, optional + Weight for the Surface-Dice loss. + surf_dice_tokens : list, optional + List of tokens to use for the Surface-Dice loss. + + Returns + ------- + None + """ + print("\033[1mTraining Parameters Overview:\033[0m\n") + print( + "Data Directory:\n '{}' \n Path to the directory containing " + "training data.".format(data_dir) + ) + print("————————————————————————————————————————————————————————") + print( + "Log Directory:\n '{}' \n Directory where logs and outputs will " + "be stored.".format(log_dir) + ) + print("————————————————————————————————————————————————————————") + print( + "Batch Size:\n {} \n Number of samples processed in a single batch.".format( + batch_size + ) + ) + print("————————————————————————————————————————————————————————") + print( + "Number of Workers:\n {} \n Subprocesses to use for data " + "loading.".format(num_workers) + ) + print("————————————————————————————————————————————————————————") + print(f"Max Epochs:\n {max_epochs} \n Maximum number of training epochs.") + print("————————————————————————————————————————————————————————") + aug_status = "Enabled" if aug_prob_to_one else "Disabled" + print( + "Augmentation Probability to One:\n {} \n If enabled, sets all " + "augmentation probabilities to 1. (strong augmentation)".format(aug_status) + ) + print("————————————————————————————————————————————————————————") + deep_sup_status = "Enabled" if use_deep_supervision else "Disabled" + print( + "Use Deep Supervision:\n {} \n If enabled, activates deep " + "supervision in model.".format(deep_sup_status) + ) + print("————————————————————————————————————————————————————————") + print( + "Project Name:\n '{}' \n Name identifier for the current" + " training session.".format(project_name) + ) + print("————————————————————————————————————————————————————————") + print( + "Sub Name:\n '{}' \n Additional sub-identifier for organizing" + " outputs.".format(sub_name) + ) + print("————————————————————————————————————————————————————————") + surf_dice_status = "Enabled" if use_surf_dice else "Disabled" + print( + "Use Surface Dice:\n {} \n If enabled, includes Surface-Dice in the loss " + "calculation.".format(surf_dice_status) + ) + print("————————————————————————————————————————————————————————") + print( + "Surface Dice Weight:\n {} \n Weighting of the Surface-Dice" + " loss, if enabled.".format(surf_dice_weight) + ) + print("————————————————————————————————————————————————————————") + if surf_dice_tokens: + tokens = ", ".join(surf_dice_tokens) + print( + "Surface Dice Tokens:\n [{}] \n Specific tokens used for " + "Surface-Dice loss. Other tokens will be neglected.".format(tokens) + ) + else: + print( + "Surface Dice Tokens:\n None \n No specific tokens are used for " + "Surface-Dice loss." + ) + print("\n") diff --git a/tests/membrain_seg/training/test_optim_utils.py b/tests/membrain_seg/training/test_optim_utils.py index 4151d7a..dd0433c 100644 --- a/tests/membrain_seg/training/test_optim_utils.py +++ b/tests/membrain_seg/training/test_optim_utils.py @@ -12,6 +12,7 @@ def test_loss_fn_correctness(): import torch from membrain_seg.segmentation.training.optim_utils import ( + CombinedLoss, DeepSuperVisionLoss, IgnoreLabelDiceCELoss, ) @@ -73,17 +74,24 @@ def extend_labels(labels): pred_labels[2][pred_labels[2] < 0.0] = 0.0 ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="mean") + combined_loss = CombinedLoss( + losses=[ignore_dice_loss], weights=[1.0], loss_inclusion_tokens=["ds1"] + ) losses = test_ignore_dice_loss(ignore_dice_loss, pred_labels, gt_labels) assert losses[0] == losses[1] == losses[2] == losses[4] != losses[3] deep_supervision_loss = DeepSuperVisionLoss( - ignore_dice_loss, weights=[1.0, 0.5, 0.25, 0.125, 0.0675] + combined_loss, weights=[1.0, 0.5, 0.25, 0.125, 0.0675] ) gt_labels_ds = extend_labels(gt_labels) ds_losses = [] for pred_label in pred_labels: pred_labels_ds = extend_labels(pred_label) - ds_losses.append(deep_supervision_loss(pred_labels_ds, gt_labels_ds)) + ds_losses.append( + deep_supervision_loss( + pred_labels_ds, gt_labels_ds, ["ds1"] * len(gt_labels_ds) + ) + ) assert ds_losses[0] == ds_losses[1] == ds_losses[2] == ds_losses[4] != ds_losses[3] print("All ignore loss assertions passed.")