diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000..791b9c9 --- /dev/null +++ b/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,15 @@ +* membrain-seg version: +* Python version: +* Operating System: + +### Description + +Describe what you were trying to get done. +Tell us what happened, what went wrong, and what you expected to happen. + +### What I Did + +``` +Paste the command(s) you ran and the output. +If there was a crash, please include the traceback here. +``` diff --git a/.github/TEST_FAIL_TEMPLATE.md b/.github/TEST_FAIL_TEMPLATE.md new file mode 100644 index 0000000..3512972 --- /dev/null +++ b/.github/TEST_FAIL_TEMPLATE.md @@ -0,0 +1,12 @@ +--- +title: "{{ env.TITLE }}" +labels: [bug] +--- +The {{ workflow }} workflow failed on {{ date | date("YYYY-MM-DD HH:mm") }} UTC + +The most recent failing test was on {{ env.PLATFORM }} py{{ env.PYTHON }} +with commit: {{ sha }} + +Full run: https://github.com/{{ repo }}/actions/runs/{{ env.RUN_ID }} + +(This post will be updated if another test fails, as long as this issue remains open.) diff --git a/.github/workflows/build-and-deploy-docs.yml b/.github/workflows/build-and-deploy-docs.yml index 20892fe..2de9d12 100644 --- a/.github/workflows/build-and-deploy-docs.yml +++ b/.github/workflows/build-and-deploy-docs.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v4 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 71ec914..6818144 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,13 +9,21 @@ on: pull_request: workflow_dispatch: schedule: - - cron: "0 0 * * 0" # every week (for --pre release tests) + # run every week (for --pre release tests) + - cron: "0 0 * * 0" + +# cancel in-progress runs that use the same workflow and branch +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true jobs: check-manifest: + # check-manifest is a tool that checks that all files in version control are + # included in the sdist (unless explicitly excluded) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - run: pipx run check-manifest test: @@ -24,8 +32,8 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.8', '3.9', '3.10'] - platform: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.9", "3.10", "3.11"] + platform: [ubuntu-latest] #, macos-latest, windows-latest] steps: - name: Cancel Previous Runs @@ -33,7 +41,7 @@ jobs: with: access_token: ${{ github.token }} - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 @@ -42,17 +50,17 @@ jobs: cache-dependency-path: "pyproject.toml" cache: "pip" - # if running a cron job, we add the --pre flag to test against pre-releases - - name: Install dependencies + - name: Install Dependencies run: | python -m pip install -U pip - python -m pip install -e .[test] ${{ github.event_name == 'schedule' && '--pre' || '' }} + # if running a cron job, we add the --pre flag to test against pre-releases + python -m pip install .[test] ${{ github.event_name == 'schedule' && '--pre' || '' }} - - name: Test + - name: ๐Ÿงช Run Tests run: pytest --color=yes --cov --cov-report=xml --cov-report=term-missing - # If something goes wrong, we can open an issue in the repo - - name: Report --pre Failures + # If something goes wrong with --pre tests, we can open an issue in the repo + - name: ๐Ÿ“ Report --pre Failures if: failure() && github.event_name == 'schedule' uses: JasonEtco/create-an-issue@v2 env: @@ -60,7 +68,7 @@ jobs: PLATFORM: ${{ matrix.platform }} PYTHON: ${{ matrix.python-version }} RUN_ID: ${{ github.run_id }} - TITLE: '[test-bot] pip install --pre is failing' + TITLE: "[test-bot] pip install --pre is failing" with: filename: .github/TEST_FAIL_TEMPLATE.md update_existing: true @@ -74,28 +82,32 @@ jobs: if: success() && startsWith(github.ref, 'refs/tags/') && github.event_name != 'schedule' runs-on: ubuntu-latest + permissions: + # IMPORTANT: this permission is mandatory for trusted publishing on PyPi + # see https://docs.pypi.org/trusted-publishers/ + id-token: write + # This permission allows writing releases + contents: write + steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - - name: Set up Python + - name: ๐Ÿ Set up Python uses: actions/setup-python@v4 with: python-version: "3.x" - - name: install + - name: ๐Ÿ‘ท Build run: | - git tag - pip install -U pip build twine + python -m pip install build python -m build - twine check dist/* - ls -lh dist - - name: Build and publish - run: twine upload dist/* - env: - TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} + - name: ๐Ÿšข Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.TWINE_API_KEY }} - uses: softprops/action-gh-release@v1 with: generate_release_notes: true + files: './dist/*' diff --git a/docs/installation.md b/docs/installation.md index d2f6a9f..cd47231 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -52,6 +52,9 @@ This should display the different options you can choose from MemBrain, like "se ## Step 5: Download pre-trained segmentation model (optional) We recommend to use denoised (ideally Cryo-CARE1) tomograms for segmentation. However, our current best model is available for download [here](https://drive.google.com/file/d/1tSQIz_UCsQZNfyHg0RxD-4meFgolszo8/view?usp=sharing) and should also work on non-denoised data. Please let us know how it works for you. + +NOTE: Previous model files are not compatible with MONAI v1.3.0 or higher. So if you're using v1.3.0 or higher, consider downgrading to MONAI v1.2.0 or downloading this [adapted version](https://drive.google.com/file/d/1Tfg2Ju-cgSj_71_b1gVMnjqNYea7L1Hm/view?usp=sharing) of our most recent model file. + If the given model does not work properly, you may want to try one of our previous versions: Other (older) model versions: @@ -65,3 +68,16 @@ Once downloaded, you can use it in MemBrain-seg's [Segmentation](./Usage/Segment ``` [1] T. -O. Buchholz, M. Jordan, G. Pigino and F. Jug, "Cryo-CARE: Content-Aware Image Restoration for Cryo-Transmission Electron Microscopy Data," 2019 IEEE 16th International Symposium on Biomedical Imaging (ISBI 2019), Venice, Italy, 2019, pp. 502-506, doi: 10.1109/ISBI.2019.8759519. ``` + + +# Troubleshooting +Here is a collection of common issues and how to fix them: + +- `RuntimeError: The NVIDIA driver on your system is too old (found version 11070). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has +been compiled with your version of the CUDA driver.` + + The latest Pytorch versions require higher CUDA versions that may not be installed on your system yet. You can either install the new CUDA version or (maybe easier) downgrade Pytorch to a version that is compatible: + + `pip uninstall torch` + + `pip install torch==2.0.1` \ No newline at end of file diff --git a/src/membrain_seg/annotations/extract_patch_cli.py b/src/membrain_seg/annotations/extract_patch_cli.py index 339e476..a5074f2 100644 --- a/src/membrain_seg/annotations/extract_patch_cli.py +++ b/src/membrain_seg/annotations/extract_patch_cli.py @@ -21,6 +21,11 @@ def extract_patches( help="Path to the folder where extracted patches should be stored. \ (subdirectories will be created)", ), + ds_token: str = Option( # noqa: B008 + "other", + help="Dataset token. Important for distinguishing between different \ + datasets. Should NOT contain underscores!", + ), coords_file: str = Option( # noqa: B008 None, help="Path to a file containing coordinates for patch extraction. The file \ @@ -93,6 +98,7 @@ def extract_patches( coords=coords, out_dir=out_folder, idx_add=idx_add, + ds_token=ds_token, token=token, pad_value=pad_value, ) diff --git a/src/membrain_seg/annotations/extract_patches.py b/src/membrain_seg/annotations/extract_patches.py index d9628ab..7e5ff78 100644 --- a/src/membrain_seg/annotations/extract_patches.py +++ b/src/membrain_seg/annotations/extract_patches.py @@ -51,7 +51,7 @@ def pad_labels(patch, padding, pad_value=2.0): def get_out_files_and_patch_number( - token, out_folder_raw, out_folder_lab, patch_nr, idx_add + ds_token, token, out_folder_raw, out_folder_lab, patch_nr, idx_add ): """ Create filenames and corrected patch numbers. @@ -62,8 +62,10 @@ def get_out_files_and_patch_number( Parameters ---------- + ds_token : str + The dataset identifier used as a part of the filename. token : str - The unique identifier used as a part of the filename. + The tomogram identifier used as a part of the filename. out_folder_raw : str The directory path where raw data patches are stored. out_folder_lab : str @@ -96,27 +98,34 @@ def get_out_files_and_patch_number( """ patch_nr += idx_add out_file_patch = os.path.join( - out_folder_raw, token + "_patch" + str(patch_nr) + "_raw.nii.gz" + out_folder_raw, ds_token + "_" + token + "_patch" + str(patch_nr) + ".nii.gz" ) out_file_patch_label = os.path.join( - out_folder_lab, token + "_patch" + str(patch_nr) + "_labels.nii.gz" + out_folder_lab, ds_token + "_" + token + "_patch" + str(patch_nr) + ".nii.gz" ) exist_add = 0 while os.path.isfile(out_file_patch): exist_add += 1 out_file_patch = os.path.join( out_folder_raw, - token + "_patch" + str(patch_nr + exist_add) + "_raw.nii.gz", + ds_token + "_" + token + "_patch" + str(patch_nr + exist_add) + ".nii.gz", ) out_file_patch_label = os.path.join( out_folder_lab, - token + "_patch" + str(patch_nr + exist_add) + "_labels.nii.gz", + ds_token + "_" + token + "_patch" + str(patch_nr + exist_add) + ".nii.gz", ) return patch_nr + exist_add, out_file_patch, out_file_patch_label def extract_patches( - tomo_path, seg_path, coords, out_dir, idx_add=0, token=None, pad_value=2.0 + tomo_path, + seg_path, + coords, + out_dir, + ds_token="other", + token=None, + idx_add=0, + pad_value=2.0, ): """ Extracts 3D patches from a given tomogram and corresponding segmentation. @@ -133,11 +142,13 @@ def extract_patches( List of tuples where each tuple represents the 3D coordinates of a patch center. out_dir : str The output directory where the extracted patches will be saved. - idx_add : int, optional - The index addition for patch numbering, default is 0. + ds_token : str, optional + Dataset token to uniquely identify the dataset, default is 'other'. token : str, optional Token to uniquely identify the tomogram, default is None. If None, the base name of the tomogram file path is used. + idx_add : int, optional + The index addition for patch numbering, default is 0. pad_value: float, optional Borders of extracted patch are padded with this value ("ignore" label) @@ -170,7 +181,7 @@ def extract_patches( for patch_nr, cur_coords in enumerate(coords): patch_nr, out_file_patch, out_file_patch_label = get_out_files_and_patch_number( - token, out_folder_raw, out_folder_lab, patch_nr, idx_add + ds_token, token, out_folder_raw, out_folder_lab, patch_nr, idx_add ) print("Extracting patch nr", patch_nr, "from tomo", token) try: diff --git a/src/membrain_seg/annotations/merge_corrections.py b/src/membrain_seg/annotations/merge_corrections.py index 361c64a..e2c0d68 100644 --- a/src/membrain_seg/annotations/merge_corrections.py +++ b/src/membrain_seg/annotations/merge_corrections.py @@ -46,13 +46,15 @@ def get_corrections_from_folder(folder_name, orig_pred_file): or filename.startswith("Ignore") or filename.startswith("ignore") ): - print("ATTENTION! Not processing", filename) - print("Is this intended?") + print( + "File does not fit into Add/Remove/Ignore naming! " "Not processing", + filename, + ) continue readdata = sitk.GetArrayFromImage( sitk.ReadImage(os.path.join(folder_name, filename)) ) - print("Adding file", filename, "<--") + print("Adding file", filename) if filename.startswith("Add") or filename.startswith("add"): add_patch += readdata diff --git a/src/membrain_seg/segmentation/cli/segment_cli.py b/src/membrain_seg/segmentation/cli/segment_cli.py index 5125e8e..b96cbec 100644 --- a/src/membrain_seg/segmentation/cli/segment_cli.py +++ b/src/membrain_seg/segmentation/cli/segment_cli.py @@ -78,7 +78,7 @@ def segment( @cli.command(name="components", no_args_is_help=True) def components( segmentation_path: str = Option( # noqa: B008 - help="Path to the membrane segmentation to be processed.", **PKWARGS + ..., help="Path to the membrane segmentation to be processed.", **PKWARGS ), out_folder: str = Option( # noqa: B008 "./predictions", help="Path to the folder where segmentations should be stored." @@ -114,7 +114,7 @@ def components( @cli.command(name="thresholds", no_args_is_help=True) def thresholds( scoremap_path: str = Option( # noqa: B008 - help="Path to the membrane scoremap to be processed.", **PKWARGS + ..., help="Path to the membrane scoremap to be processed.", **PKWARGS ), thresholds: List[float] = Option( # noqa: B008 ..., diff --git a/src/membrain_seg/segmentation/cli/train_cli.py b/src/membrain_seg/segmentation/cli/train_cli.py index ac3d394..8cc6132 100644 --- a/src/membrain_seg/segmentation/cli/train_cli.py +++ b/src/membrain_seg/segmentation/cli/train_cli.py @@ -1,4 +1,7 @@ +from typing import List, Optional + from typer import Option +from typing_extensions import Annotated from ..train import train as _train from .cli import OPTION_PROMPT_KWARGS as PKWARGS @@ -70,7 +73,7 @@ def train_advanced( help="Batch size for training.", ), num_workers: int = Option( # noqa: B008 - 1, + 8, help="Number of worker threads for loading data", ), max_epochs: int = Option( # noqa: B008 @@ -84,6 +87,22 @@ def train_advanced( but also severely increases training time.\ Pass "True" or "False".', ), + use_surface_dice: bool = Option( # noqa: B008 + False, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".' + ), + surface_dice_weight: float = Option( # noqa: B008 + 1.0, help="Scaling factor for the Surface-Dice loss. " + ), + surface_dice_tokens: Annotated[ + Optional[List[str]], + Option( + help='List of tokens to \ + use for the Surface-Dice loss. \ + Pass tokens separately:\ + For example, train_advanced --surface_dice_tokens "ds1" \ + --surface_dice_tokens "ds2"' + ), + ] = None, use_deep_supervision: bool = Option( # noqa: B008 True, help='Whether to use deep supervision. Pass "True" or "False".' ), @@ -119,6 +138,12 @@ def train_advanced( If set to False, data augmentation still happens, but not as frequently. More data augmentation can lead to a better performance, but also increases the training time substantially. + use_surface_dice : bool + Determines whether to use Surface-Dice loss, by default True. + surface_dice_weight : float + Scaling factor for the Surface-Dice loss, by default 1.0. + surface_dice_tokens : list + List of tokens to use for the Surface-Dice loss, by default ["all"]. use_deep_supervision : bool Determines whether to use deep supervision, by default True. project_name : str @@ -140,6 +165,9 @@ def train_advanced( max_epochs=max_epochs, aug_prob_to_one=aug_prob_to_one, use_deep_supervision=use_deep_supervision, + use_surf_dice=use_surface_dice, + surf_dice_weight=surface_dice_weight, + surf_dice_tokens=surface_dice_tokens, project_name=project_name, sub_name=sub_name, ) diff --git a/src/membrain_seg/segmentation/dataloading/data_utils.py b/src/membrain_seg/segmentation/dataloading/data_utils.py index 6662708..b306f9d 100644 --- a/src/membrain_seg/segmentation/dataloading/data_utils.py +++ b/src/membrain_seg/segmentation/dataloading/data_utils.py @@ -175,7 +175,8 @@ def store_segmented_tomograms( out_folder = out_folder if store_probabilities: out_file = os.path.join( - out_folder, os.path.basename(orig_data_path)[:-4] + "_scores.mrc" + out_folder, + os.path.splitext(os.path.basename(orig_data_path))[0] + "_scores.mrc", ) out_tomo = Tomogram( data=predictions_np, header=mrc_header, voxel_size=voxel_size @@ -186,7 +187,10 @@ def store_segmented_tomograms( ) out_file_thres = os.path.join( out_folder, - os.path.basename(orig_data_path)[:-4] + "_" + ckpt_token + "_segmented.mrc", + os.path.splitext(os.path.basename(orig_data_path))[0] + + "_" + + ckpt_token + + "_segmented.mrc", ) if store_connected_components: predictions_np_thres = connected_components( diff --git a/src/membrain_seg/segmentation/dataloading/memseg_dataset.py b/src/membrain_seg/segmentation/dataloading/memseg_dataset.py index d856883..9b66dda 100644 --- a/src/membrain_seg/segmentation/dataloading/memseg_dataset.py +++ b/src/membrain_seg/segmentation/dataloading/memseg_dataset.py @@ -1,7 +1,6 @@ import os from typing import Dict -# from skimage import io import imageio as io import numpy as np from torch.utils.data import Dataset @@ -102,6 +101,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: "label": np.expand_dims(self.labels[idx], 0), } idx_dict = self.transforms(idx_dict) + idx_dict["dataset"] = self.dataset_labels[idx] return idx_dict def __len__(self) -> int: @@ -126,6 +126,7 @@ def load_data(self) -> None: print("Loading images into dataset.") self.imgs = [] self.labels = [] + self.dataset_labels = [] for entry in self.data_paths: label = read_nifti( entry[1] @@ -137,6 +138,7 @@ def load_data(self) -> None: img = np.transpose(img, (1, 2, 0)) self.imgs.append(img) self.labels.append(label) + self.dataset_labels.append(get_dataset_token(entry[0])) def initialize_imgs_paths(self) -> None: """ @@ -190,3 +192,23 @@ def test(self, test_folder: str, num_files: int = 20) -> None: os.path.join(test_folder, f"test_mask_ds2_{i}_group{num_mask}.png"), test_sample["label"][1][0, :, :, num_mask], ) + + +def get_dataset_token(patch_name): + """ + Get the dataset token from the patch name. + + Parameters + ---------- + patch_name : str + The name of the patch. + + Returns + ------- + str + The dataset token. + + """ + basename = os.path.basename(patch_name) + dataset_token = basename.split("_")[0] + return dataset_token diff --git a/src/membrain_seg/segmentation/networks/unet.py b/src/membrain_seg/segmentation/networks/unet.py index 7a67310..711f1d1 100644 --- a/src/membrain_seg/segmentation/networks/unet.py +++ b/src/membrain_seg/segmentation/networks/unet.py @@ -8,17 +8,13 @@ from monai.transforms import AsDiscrete, Compose, EnsureType, Lambda from ..training.metric_utils import masked_accuracy, threshold_function - -# from monai.networks.nets import UNet as MonaiUnet -# The normal Monai DynUNet upsamples low-resolution layers to compare directly to GT -# My implementation leaves them in low resolution and compares to down-sampled GT -# Not sure which implementation is better -# To be discussed with Alister & Kevin from ..training.optim_utils import ( + CombinedLoss, DeepSuperVisionLoss, - DynUNetDirectDeepSupervision, # I like to use deep supervision + DynUNetDirectDeepSupervision, IgnoreLabelDiceCELoss, ) +from ..training.surface_dice import IgnoreLabelSurfaceDiceLoss, masked_surface_dice class SemanticSegmentationUnet(pl.LightningModule): @@ -62,6 +58,12 @@ class SemanticSegmentationUnet(pl.LightningModule): The maximum number of epochs for training. use_deep_supervision : bool, default=False Whether to use deep supervision. + use_surf_dice : bool, default=False + Whether to use surface dice loss. + surf_dice_weight : float, default=1.0 + The weight for the surface dice loss. + surf_dice_tokens : list, default=[] + The tokens for which to compute the surface dice loss. """ @@ -80,6 +82,9 @@ def __init__( roi_size: Tuple[int, ...] = (160, 160, 160), max_epochs: int = 1000, use_deep_supervision: bool = False, + use_surf_dice: bool = False, + surf_dice_weight: float = 1.0, + surf_dice_tokens: list = None, ): super().__init__() @@ -102,16 +107,39 @@ def __init__( upsample_kernel_size=(1, 2, 2, 2, 2, 2), filters=channels, res_block=True, - # norm_name="INSTANCE", - # norm=Norm.INSTANCE, # I like the instance normalization better than - # batchnorm in this case, as we will probably have - # only small batch sizes, making BN more noisy deep_supervision=True, deep_supr_num=2, ) - ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="mean") + + ### Build up loss function + losses = [] + weights = [] + loss_inclusion_tokens = [] + ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="none") + losses.append(ignore_dice_loss) + weights.append(1.0) + loss_inclusion_tokens.append(["all"]) # Apply to every element + + if use_surf_dice: + if surf_dice_tokens is None: + surf_dice_tokens = ["all"] + ignore_surf_dice_loss = IgnoreLabelSurfaceDiceLoss( + ignore_label=2, soft_skel_iterations=5 + ) + losses.append(ignore_surf_dice_loss) + weights.append(surf_dice_weight) + loss_inclusion_tokens.append(surf_dice_tokens) + + scaled_weights = [entry / sum(weights) for entry in weights] + + loss_function = CombinedLoss( + losses=losses, + weights=scaled_weights, + loss_inclusion_tokens=loss_inclusion_tokens, + ) + self.loss_function = DeepSuperVisionLoss( - ignore_dice_loss, + loss_function, weights=[1.0, 0.5, 0.25, 0.125, 0.0675] if use_deep_supervision else [1.0, 0.0, 0.0, 0.0, 0.0], @@ -143,7 +171,9 @@ def __init__( self.training_step_outputs = [] self.validation_step_outputs = [] self.running_train_acc = 0.0 + self.running_train_surf_dice = 0.0 self.running_val_acc = 0.0 + self.running_val_surf_dice = 0.0 def forward(self, x) -> torch.Tensor: """Implementation of the forward pass. @@ -180,9 +210,9 @@ def training_step( See the pytorch-lightning module documentation for details. """ - images, labels = batch["image"], batch["label"] + images, labels, ds_label = batch["image"], batch["label"], batch["dataset"] output = self.forward(images) - loss = self.loss_function(output, labels) + loss = self.loss_function(output, labels, ds_label) stats_dict = {"train_loss": loss, "train_number": output[0].shape[0]} self.training_step_outputs.append(stats_dict) @@ -190,6 +220,17 @@ def training_step( masked_accuracy(output[0], labels[0], ignore_label=2.0, threshold_value=0.0) * output[0].shape[0] ) + self.running_train_surf_dice += ( + masked_surface_dice( + data=output[0].detach(), + target=labels[0].detach(), + ignore_label=2.0, + soft_skel_iterations=5, + smooth=1.0, + reduction="mean", + ) + * output[0].shape[0] + ) return {"loss": loss} @@ -207,13 +248,17 @@ def on_train_epoch_end(self): mean_train_loss = torch.tensor(train_loss / num_items) mean_train_acc = self.running_train_acc / num_items + mean_train_surf_dice = self.running_train_surf_dice / num_items self.running_train_acc = 0.0 - self.log("train_loss", mean_train_loss) # , batch_size=num_items) - self.log("train_acc", mean_train_acc) # , batch_size=num_items) + self.running_train_surf_dice = 0.0 + self.log("train_loss", mean_train_loss) + self.log("train_acc", mean_train_acc) + self.log("train_surf_dice", mean_train_surf_dice) self.training_step_outputs = [] print("EPOCH Training loss", mean_train_loss.item()) print("EPOCH Training acc", mean_train_acc.item()) + print("EPOCH Training surface dice", mean_train_surf_dice.item()) # Accuracy not the most informative metric, but a good sanity check return {"train_loss": mean_train_loss} @@ -224,13 +269,9 @@ def validation_step(self, batch, batch_idx): using a sliding window. See the pytorch-lightning module documentation for details. """ - images, labels = batch[self.image_key], batch[self.label_key] - # sw_batch_size = 4 - # outputs = sliding_window_inference( - # images, self.roi_size, sw_batch_size, self.forward - # ) + images, labels, ds_label = batch["image"], batch["label"], batch["dataset"] outputs = self.forward(images) - loss = self.loss_function(outputs, labels) + loss = self.loss_function(outputs, labels, ds_label) # Cloning and adjusting preds & labels for Dice. # Could also use the same labels, but maybe we want to @@ -254,6 +295,18 @@ def validation_step(self, batch, batch_idx): ) * outputs[0].shape[0] ) + + self.running_val_surf_dice += ( + masked_surface_dice( + data=outputs[0].detach(), + target=labels[0].detach(), + ignore_label=2.0, + soft_skel_iterations=5, + smooth=1.0, + reduction="mean", + ) + * outputs[0].shape[0] + ) return stats_dict def on_validation_epoch_end(self): @@ -270,13 +323,17 @@ def on_validation_epoch_end(self): mean_val_loss = torch.tensor(val_loss / num_items) mean_val_acc = self.running_val_acc / num_items + mean_val_surf_dice = self.running_val_surf_dice / num_items self.running_val_acc = 0.0 - self.log("val_loss", mean_val_loss), # batch_size=num_items) - self.log("val_dice", mean_val_dice) # , batch_size=num_items) + self.running_val_surf_dice = 0.0 + self.log("val_loss", mean_val_loss), + self.log("val_dice", mean_val_dice) + self.log("val_surf_dice", mean_val_surf_dice) self.log("val_accuracy", mean_val_acc) self.validation_step_outputs = [] print("EPOCH Validation loss", mean_val_loss.item()) print("EPOCH Validation dice", mean_val_dice) + print("EPOCH Validation surface dice", mean_val_surf_dice.item()) print("EPOCH Validation acc", mean_val_acc.item()) return {"val_loss": mean_val_loss, "val_metric": mean_val_dice} diff --git a/src/membrain_seg/segmentation/segment.py b/src/membrain_seg/segmentation/segment.py index e44a327..f5ccda0 100644 --- a/src/membrain_seg/segmentation/segment.py +++ b/src/membrain_seg/segmentation/segment.py @@ -78,8 +78,9 @@ def segment( device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Initialize the model and load trained weights from checkpoint - pl_model = SemanticSegmentationUnet() - pl_model = pl_model.load_from_checkpoint(model_checkpoint, map_location=device) + pl_model = SemanticSegmentationUnet.load_from_checkpoint( + model_checkpoint, map_location=device, strict=False + ) pl_model.to(device) # Preprocess the new data diff --git a/src/membrain_seg/segmentation/train.py b/src/membrain_seg/segmentation/train.py index 2f077f2..9c576f5 100644 --- a/src/membrain_seg/segmentation/train.py +++ b/src/membrain_seg/segmentation/train.py @@ -9,6 +9,9 @@ MemBrainSegDataModule, ) from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet +from membrain_seg.segmentation.training.training_param_summary import ( + print_training_parameters, +) warnings.filterwarnings("ignore", category=UserWarning, module="torch._tensor") warnings.filterwarnings("ignore", category=UserWarning, module="monai.data") @@ -24,6 +27,9 @@ def train( use_deep_supervision: bool = False, project_name: str = "membrain-seg_v0", sub_name: str = "1", + use_surf_dice: bool = False, + surf_dice_weight: float = 1.0, + surf_dice_tokens: list = None, ): """ Train the model on the specified data. @@ -52,11 +58,31 @@ def train( Name of the project for logging purposes. sub_name : str, optional Sub-name of the project for logging purposes. + use_surf_dice : bool, optional + If True, enables Surface-Dice loss. + surf_dice_weight : float, optional + Weight for the Surface-Dice loss. + surf_dice_tokens : list, optional + List of tokens to use for the Surface-Dice loss. Returns ------- None """ + print_training_parameters( + data_dir=data_dir, + log_dir=log_dir, + batch_size=batch_size, + num_workers=num_workers, + max_epochs=max_epochs, + aug_prob_to_one=aug_prob_to_one, + use_deep_supervision=use_deep_supervision, + project_name=project_name, + sub_name=sub_name, + use_surf_dice=use_surf_dice, + surf_dice_weight=surf_dice_weight, + surf_dice_tokens=surf_dice_tokens, + ) # Set up the data module data_module = MemBrainSegDataModule( data_dir=data_dir, @@ -67,15 +93,16 @@ def train( # Set up the model model = SemanticSegmentationUnet( - max_epochs=max_epochs, use_deep_supervision=use_deep_supervision + max_epochs=max_epochs, + use_deep_supervision=use_deep_supervision, + use_surf_dice=use_surf_dice, + surf_dice_weight=surf_dice_weight, + surf_dice_tokens=surf_dice_tokens, ) project_name = project_name checkpointing_name = project_name + "_" + sub_name # Set up logging - wandb_logger = pl_loggers.WandbLogger( - project=project_name, log_model=False, save_code=True - ) csv_logger = pl_loggers.CSVLogger(log_dir) # Set up model checkpointing @@ -106,7 +133,7 @@ def on_epoch_start(self, trainer, pl_module): # Set up the trainer trainer = pl.Trainer( precision="16-mixed", - logger=[csv_logger, wandb_logger], + logger=[csv_logger], callbacks=[ checkpoint_callback_val_loss, checkpoint_callback_regular, diff --git a/src/membrain_seg/segmentation/training/metric_utils.py b/src/membrain_seg/segmentation/training/metric_utils.py index f30a258..ad8f300 100644 --- a/src/membrain_seg/segmentation/training/metric_utils.py +++ b/src/membrain_seg/segmentation/training/metric_utils.py @@ -34,7 +34,7 @@ def masked_accuracy( mask = ( y_gt == ignore_label if ignore_label is not None - else torch.ones_like(y_gt).bool() + else torch.zeros_like(y_gt).bool() ) acc = (threshold_function(y_pred, threshold_value=threshold_value) == y_gt).float() acc[mask] = 0.0 diff --git a/src/membrain_seg/segmentation/training/optim_utils.py b/src/membrain_seg/segmentation/training/optim_utils.py index 8dbdc6b..a16e136 100644 --- a/src/membrain_seg/segmentation/training/optim_utils.py +++ b/src/membrain_seg/segmentation/training/optim_utils.py @@ -53,7 +53,7 @@ class IgnoreLabelDiceCELoss(_Loss): def __init__( self, ignore_label: int, - reduction: str = "mean", + reduction: str = "none", lambda_dice: float = 1.0, lambda_ce: float = 1.0, **kwargs, @@ -95,11 +95,31 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor: orig_data, target_tensor, reduction="none" ) bce_loss[~mask] = 0.0 - bce_loss = torch.sum(bce_loss) / torch.sum(mask) - dice_loss = self.dice_loss(data, target, mask) + # TODO: Check if this is correct: I adjusted the loss to be + # computed per batch element + bce_loss = torch.sum(bce_loss, dim=(1, 2, 3, 4)) / torch.sum( + mask, dim=(1, 2, 3, 4) + ) + # Compute Dice loss separately for each batch element + dice_loss = torch.zeros_like(bce_loss) + for batch_idx in range(data.shape[0]): + dice_loss[batch_idx] = self.dice_loss( + data[batch_idx].unsqueeze(0), + target[batch_idx].unsqueeze(0), + mask[batch_idx].unsqueeze(0), + ) # Combine the Dice and Cross Entropy losses combined_loss = self.lambda_dice * dice_loss + self.lambda_ce * bce_loss + if self.reduction == "mean": + combined_loss = combined_loss.mean() + elif self.reduction == "sum": + combined_loss = combined_loss.sum() + else: + raise ValueError( + f"Invalid reduction type {self.reduction}. " + "Valid options are 'mean' and 'sum'." + ) return combined_loss @@ -134,7 +154,7 @@ def __init__( self.loss_fn = loss_fn self.weights = weights - def forward(self, inputs: list, targets: list) -> torch.Tensor: + def forward(self, inputs: list, targets: list, ds_labels: list) -> torch.Tensor: """ Compute the loss. @@ -144,6 +164,8 @@ def forward(self, inputs: list, targets: list) -> torch.Tensor: List of tensors of model outputs. targets : list List of tensors of target labels. + ds_labels : list + List of dataset labels for each batch element. Returns ------- @@ -151,6 +173,96 @@ def forward(self, inputs: list, targets: list) -> torch.Tensor: The calculated loss. """ loss = 0.0 - for weight, data, target in zip(self.weights, inputs, targets): - loss += weight * self.loss_fn(data, target) + ds_labels_loop = [ds_labels] * 5 + for weight, data, target, ds_label in zip( + self.weights, inputs, targets, ds_labels_loop + ): + loss += weight * self.loss_fn(data, target, ds_label) + return loss + + +class CombinedLoss(_Loss): + """ + Combine multiple loss functions into a single one. + + Parameters + ---------- + losses : List[Callable] + A list of loss function instances. + weights : List[float] + List of weights corresponding to each loss function (must + be of same length as losses). + loss_inclusion_tokens : List[List[str]] + A list of lists containing tokens for each loss function. + Each sublist corresponds to a loss function and contains + tokens for which the loss should be included. + If the list contains "all", then the loss will be included + for all cases. + + Notes + ----- + IMPORTANT: Loss functions need to return a tensors containing the + loss for each batch element. + + The loss_exclusion_tokens parameter is used to exclude certain + cases from the loss calculation. For example, if the loss_exclusion_tokens + parameter is [["ds1", "ds2"], ["ds1"]], then the first loss function + will be excluded for cases where the dataset label is "ds1" or "ds2", + and the second loss function will be excluded for cases where the + dataset label is "ds1". + """ + + def __init__( + self, + losses: list, + weights: list, + loss_inclusion_tokens: list, + **kwargs, + ) -> None: + super().__init__() + self.losses = losses + self.weights = weights + self.loss_inclusion_tokens = loss_inclusion_tokens + + def forward( + self, data: torch.Tensor, target: torch.Tensor, ds_label: list + ) -> torch.Tensor: + """ + Compute the combined loss. + + Parameters + ---------- + data : torch.Tensor + Tensor of model outputs. + target : torch.Tensor + Tensor of target labels. + ds_label : List[str] + List of dataset labels for each batch element. + + Returns + ------- + torch.Tensor + The calculated combined loss. + """ + loss = 0.0 + for loss_idx, (cur_loss, cur_weight) in enumerate( + zip(self.losses, self.weights) + ): + cur_loss_val = cur_loss(data, target) + + # Zero out losses for excluded cases + for batch_idx, ds_lab in enumerate(ds_label): + if ( + "all" in self.loss_inclusion_tokens[loss_idx] + or ds_lab in self.loss_inclusion_tokens[loss_idx] + ): + continue + cur_loss_val[batch_idx] = 0.0 + + # Aggregate loss + cur_loss_val = cur_loss_val.sum() / ((cur_loss_val != 0.0).sum() + 1e-3) + loss += cur_weight * cur_loss_val + + # Normalize loss + loss = loss / sum(self.weights) return loss diff --git a/src/membrain_seg/segmentation/training/surface_dice.py b/src/membrain_seg/segmentation/training/surface_dice.py new file mode 100644 index 0000000..bc2c6b4 --- /dev/null +++ b/src/membrain_seg/segmentation/training/surface_dice.py @@ -0,0 +1,458 @@ +""" +Surface Dice implementation. + +Adapted from: clDice - A Novel Topology-Preserving Loss Function for Tubular +Structure Segmentation +Original Authors: Johannes C. Paetzold and Suprosanna Shit +Sources: https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/ + soft_skeleton.py + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py +License: MIT License. + +The following code is a modification of the original clDice implementation. +Modifications were made to include additional functionality and integrate +with new project requirements. The original license and copyright notice are +provided below. + +MIT License + +Copyright (c) 2021 Johannes C. Paetzold and Suprosanna Shit + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import math + +import torch +import torch.nn.functional as F +from torch.nn.functional import sigmoid +from torch.nn.modules.loss import _Loss + + +def soft_erode(img: torch.Tensor, separate_pool: bool = False) -> torch.Tensor: + """ + Apply soft erosion operation to the input image. + + Soft erosion is achieved by applying a min-pooling operation to the input image. + + Parameters + ---------- + img : torch.Tensor + Input image tensor with shape (B, C, D, H, W) + separate_pool : bool, optional + If True, perform separate 3D max-pooling operations along different axes. + Default is False. + + Returns + ------- + torch.Tensor + Eroded image tensor with the same shape as the input. + + Raises + ------ + ValueError + If the input tensor has an unsupported number of dimensions. + + Notes + ----- + - The soft erosion can be performed with separate 3D min-pooling operations + along different axes if separate_pool is True, or with a single + 3D min-pooling operation with a kernel of size (3, 3, 3) if + separate_pool is False. + """ + assert len(img.shape) == 5 + if separate_pool: + p1 = -F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0)) + p2 = -F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0)) + p3 = -F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1)) + return torch.min(torch.min(p1, p2), p3) + p4 = -F.max_pool3d(-img, (3, 3, 3), (1, 1, 1), (1, 1, 1)) + return p4 + + +def soft_dilate(img: torch.Tensor) -> torch.Tensor: + """ + Apply soft dilation operation to the input image. + + Soft dilation is achieved by applying a max-pooling operation to the input image. + + Parameters + ---------- + img : torch.Tensor + Input image tensor with shape (B, C, D, H, W). + + Returns + ------- + torch.Tensor + Dilated image tensor with the same shape as the input. + + Raises + ------ + ValueError + If the input tensor has an unsupported number of dimensions. + + Notes + ----- + - For 5D input, the soft dilation is performed using a 3D max-pooling operation + with a kernel of size (3, 3, 3). + """ + assert len(img.shape) == 5 + return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), (1, 1, 1)) + + +def soft_open(img: torch.Tensor, separate_pool: bool = False) -> torch.Tensor: + """ + Apply soft opening operation to the input image. + + Soft opening is achieved by applying soft erosion followed by soft dilation. + The intention of soft opening is to remove thin membranes from the segmentation. + + Parameters + ---------- + img : torch.Tensor + Input image tensor with shape (B, C, D, H, W). + separate_pool : bool, optional + If True, perform separate erosion and dilation operations. Default is False. + + Returns + ------- + torch.Tensor + Opened image tensor with the same shape as the input. + + Notes + ----- + - Soft opening is performed by applying soft erosion followed by soft dilation + to the input image. + - For 5D input, separate erosion and dilation can be performed if separate_pool + is True. + """ + return soft_dilate(soft_erode(img, separate_pool=separate_pool)) + + +def soft_skel( + img: torch.Tensor, iter_: int, separate_pool: bool = False +) -> torch.Tensor: + """ + Compute the soft skeleton of the input image. + + The skeleton is computed by applying soft erosion iteratively to the input image. + In each iteration, the difference between the input image and the "opened" image is + computed and added to the skeleton. + + Reasoning: if there is a difference between the input image and the "opened" image, + there must be a thin membrane skeleton in the input image that was removed by the + opening operation. + + Parameters + ---------- + img : torch.Tensor + Input image tensor with shape (B, C, D, H, W). + iter_ : int + Number of iterations for skeletonization. + separate_pool : bool, optional + If True, perform separate erosion and dilation operations. + Default is False. + + Returns + ------- + torch.Tensor + Soft skeleton image tensor with the same shape as the input. + + Notes + ----- + - Separate erosion can be performed if separate_pool is True. + """ + img1 = soft_open(img, separate_pool=separate_pool) + skel = F.relu(img - img1) + for _j in range(iter_): + img = soft_erode(img) + img1 = soft_open(img, separate_pool=separate_pool) + delta = F.relu(img - img1) + skel = skel + F.relu(delta - skel * delta) + return skel + + +def gaussian_kernel(size: int, sigma: float) -> torch.Tensor: + """ + Creates a 3D Gaussian kernel using the specified size and sigma. + + Parameters + ---------- + size : int + The size of the Gaussian kernel. It determines the length of + each dimension of the cube. + sigma : float + The standard deviation of the Gaussian kernel. It controls + the spread of the Gaussian. + + Returns + ------- + torch.Tensor + A 3D tensor representing the Gaussian kernel. + + Notes + ----- + The function creates a Gaussian kernel, which is essentially a + cube of dimensions [size, size, size]. Each entry in the cube is + computed using the Gaussian function based on its distance from the center. + The kernel is normalized so that its total sum equals 1. + """ + # Define a coordinate grid centered at (0,0,0) + grid = torch.arange(size, dtype=torch.float32) - (size - 1) / 2 + # Create a 3D meshgrid + x, y, z = torch.meshgrid(grid, grid, grid) + xyz_grid = torch.stack([x, y, z], dim=-1) + + # Calculate the 3D Gaussian kernel + gaussian_kernel = torch.exp(-torch.sum(xyz_grid**2, dim=-1) / (2 * sigma**2)) + gaussian_kernel /= (2 * math.pi * sigma**2) ** (3 / 2) # Normalize + + # Ensure sum of values in gaussian kernel equals 1. + gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) + return gaussian_kernel + + +gaussian_kernel_dict = {} +""" Not sure why, but moving the gaussian kernel to GPU takes surprisingly long.+ +So we precompute it, store it on GPU, and reuse it. +""" + + +def apply_gaussian_filter( + seg: torch.Tensor, kernel_size: int, sigma: float +) -> torch.Tensor: + """ + Apply a Gaussian filter to a segmentation tensor using PyTorch. + + This function convolves the input tensor with a Gaussian kernel. + The function creates or retrieves a Gaussian kernel based on the + specified size and standard deviation, and applies 3D convolution to each + channel of each batch item with appropriate padding to maintain spatial + dimensions. + + Parameters + ---------- + seg : torch.Tensor + The input segmentation tensor of shape (batch, channel, X, Y, Z). + kernel_size : int + The size of the Gaussian kernel, determining the length of each + dimension of the cube. + sigma : float + The standard deviation of the Gaussian kernel, controlling the spread. + + Returns + ------- + torch.Tensor + The filtered segmentation tensor of the same shape as input. + + Notes + ----- + This function uses a precomputed dictionary to enhance performance by + storing Gaussian kernels. If a kernel with the specified size and standard + deviation does not exist in the dictionary, it is created and added. The + function assumes the input tensor is a 5D tensor, applies 3D convolution + using the Gaussian kernel with padding to maintain spatial dimensions, and + it performs the operation separately for each channel of each batch item. + """ + # Create the Gaussian kernel or load it from the dictionary + if (kernel_size, sigma) not in gaussian_kernel_dict.keys(): + gaussian_kernel_dict[(kernel_size, sigma)] = gaussian_kernel( + kernel_size, sigma + ).to(seg.device) + g_kernel = gaussian_kernel_dict[(kernel_size, sigma)] + + # Add batch and channel dimensions + g_kernel = g_kernel.view(1, 1, *g_kernel.size()) + # Apply the Gaussian filter to each channel + padding = kernel_size // 2 + + # Move the kernel to the same device as the segmentation tensor + g_kernel = g_kernel.to(seg.device) + + # Apply the Gaussian filter + filtered_seg = F.conv3d(seg, g_kernel, padding=padding, groups=seg.shape[1]) + return filtered_seg + + +def get_GT_skeleton(gt_seg: torch.Tensor, iterations: int = 5) -> torch.Tensor: + """ + Generate the skeleton of a ground truth segmentation. + + This function takes a ground truth segmentation `gt_seg`, smooths it using a + Gaussian filter, and then computes its soft skeleton using the `soft_skel` function. + + Intention: When using the binary ground truth segmentation for skeletonization, + the resulting skeleton is very patchy and not smooth. When using the smoothed + ground truth segmentation, the resulting skeleton is much smoother and more + accurate. + + Parameters + ---------- + gt_seg : torch.Tensor + A torch.Tensor representing the ground truth segmentation. + Shape: (B, C, D, H, W) + iterations : int, optional + The number of iterations for skeletonization. Default is 5. + + Returns + ------- + torch.Tensor + A torch.Tensor representing the skeleton of the ground truth segmentation. + + Notes + ----- + - The input `gt_seg` should be a binary segmentation tensor where 1 represents the + object of interest. + - The function first smooths the `gt_seg` using a Gaussian filter to enhance the + object's structure. + - The skeletonization process is performed using the `soft_skel` function with the + specified number of iterations. + - The resulting skeleton is returned as a binary torch.Tensor where 1 indicates the + skeleton points. + """ + gt_smooth = ( + apply_gaussian_filter((gt_seg == 1) * 1.0, kernel_size=15, sigma=2.0) * 1.5 + ) + skel_gt = soft_skel(gt_smooth, iter_=iterations) + return skel_gt + + +def masked_surface_dice( + data: torch.Tensor, + target: torch.Tensor, + ignore_label: int = 2, + soft_skel_iterations: int = 3, + smooth: float = 3.0, + binary_prediction: bool = False, + reduction: str = "none", +) -> torch.Tensor: + """ + Compute the surface Dice loss with masking for ignore labels. + + The surface Dice loss measures the similarity between the predicted segmentation's + skeleton and the ground truth segmentation (and vice versa). Labels annotated with + "ignore_label" are ignored. + + Parameters + ---------- + data : torch.Tensor + Tensor of model outputs representing the predicted segmentation. + Expected shape: (B, C, D, H, W) + target : torch.Tensor + Tensor of target labels representing the ground truth segmentation. + Expected shape: (B, 1, D, H, W) + ignore_label : int + The label value to be ignored when computing the loss. + soft_skel_iterations : int + Number of iterations for skeletonization in the underlying operations. + smooth : float + Smoothing factor to avoid division by zero. + binary_prediction : bool + If True, the predicted segmentation is assumed to be binary. Default is False. + reduction : str + Specifies the reduction to apply to the output. Default is "none". + + Returns + ------- + torch.Tensor + The calculated surface Dice loss. + """ + # Create a mask to ignore the specified label in the target + data = sigmoid(data) + mask = target != ignore_label + + # Compute soft skeletonization + if binary_prediction: + skel_pred = get_GT_skeleton(data.clone(), soft_skel_iterations) + else: + skel_pred = soft_skel(data.clone(), soft_skel_iterations, separate_pool=False) + skel_true = get_GT_skeleton(target.clone(), soft_skel_iterations) + + # Mask out ignore labels + skel_pred[~mask] = 0 + skel_true[~mask] = 0 + + # compute surface dice loss + tprec = ( + torch.sum(torch.multiply(skel_pred, target), dim=(1, 2, 3, 4)) + smooth + ) / (torch.sum(skel_pred, dim=(1, 2, 3, 4)) + smooth) + tsens = (torch.sum(torch.multiply(skel_true, data), dim=(1, 2, 3, 4)) + smooth) / ( + torch.sum(skel_true, dim=(1, 2, 3, 4)) + smooth + ) + surf_dice_loss = 2.0 * (tprec * tsens) / (tprec + tsens) + if reduction == "none": + return surf_dice_loss + elif reduction == "mean": + return torch.mean(surf_dice_loss) + + +class IgnoreLabelSurfaceDiceLoss(_Loss): + """ + Surface Dice loss, adding ignore labels. + + Parameters + ---------- + ignore_label : int + The label to ignore when calculating the loss. + reduction : str, optional + Specifies the reduction to apply to the output, by default "mean". + kwargs : dict + Additional keyword arguments. + """ + + def __init__( + self, + ignore_label: int, + soft_skel_iterations: int = 3, + smooth: float = 3.0, + **kwargs, + ) -> None: + super().__init__() + self.ignore_label = ignore_label + self.soft_skel_iterations = soft_skel_iterations + self.smooth = smooth + + def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute the loss. + + Parameters + ---------- + data : torch.Tensor + Tensor of model outputs. + Expected shape: (B, C, D, H, W) + target : torch.Tensor + Tensor of target labels. + Expected shape: (B, 1, D, H, W) + + Returns + ------- + torch.Tensor + The calculated loss. + """ + # Create a mask to ignore the specified label in the target + surf_dice_score = masked_surface_dice( + data=data, + target=target, + ignore_label=self.ignore_label, + soft_skel_iterations=self.soft_skel_iterations, + smooth=self.smooth, + ) + surf_dice_loss = 1.0 - surf_dice_score + return surf_dice_loss diff --git a/src/membrain_seg/segmentation/training/training_param_summary.py b/src/membrain_seg/segmentation/training/training_param_summary.py new file mode 100644 index 0000000..67277c8 --- /dev/null +++ b/src/membrain_seg/segmentation/training/training_param_summary.py @@ -0,0 +1,117 @@ +def print_training_parameters( + data_dir: str = "", + log_dir: str = "logs/", + batch_size: int = 2, + num_workers: int = 8, + max_epochs: int = 1000, + aug_prob_to_one: bool = False, + use_deep_supervision: bool = False, + project_name: str = "membrain-seg_v0", + sub_name: str = "1", + use_surf_dice: bool = False, + surf_dice_weight: float = 1.0, + surf_dice_tokens: list = None, +): + """ + Print a formatted overview of the training parameters with explanations. + + Parameters + ---------- + data_dir : str, optional + Path to the directory containing training data. + log_dir : str, optional + Path to the directory where logs should be stored. + batch_size : int, optional + Number of samples per batch of input data. + num_workers : int, optional + Number of subprocesses to use for data loading. + max_epochs : int, optional + Maximum number of epochs to train for. + aug_prob_to_one : bool, optional + If True, all augmentation probabilities are set to 1. + use_deep_supervision : bool, optional + If True, enables deep supervision in the U-Net model. + project_name : str, optional + Name of the project for logging purposes. + sub_name : str, optional + Sub-name of the project for logging purposes. + use_surf_dice : bool, optional + If True, enables Surface-Dice loss. + surf_dice_weight : float, optional + Weight for the Surface-Dice loss. + surf_dice_tokens : list, optional + List of tokens to use for the Surface-Dice loss. + + Returns + ------- + None + """ + print("\033[1mTraining Parameters Overview:\033[0m\n") + print( + "Data Directory:\n '{}' \n Path to the directory containing " + "training data.".format(data_dir) + ) + print("โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”") + print( + "Log Directory:\n '{}' \n Directory where logs and outputs will " + "be stored.".format(log_dir) + ) + print("โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”") + print( + "Batch Size:\n {} \n Number of samples processed in a single batch.".format( + batch_size + ) + ) + print("โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”") + print( + "Number of Workers:\n {} \n Subprocesses to use for data " + "loading.".format(num_workers) + ) + print("โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”") + print(f"Max Epochs:\n {max_epochs} \n Maximum number of training epochs.") + print("โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”") + aug_status = "Enabled" if aug_prob_to_one else "Disabled" + print( + "Augmentation Probability to One:\n {} \n If enabled, sets all " + "augmentation probabilities to 1. (strong augmentation)".format(aug_status) + ) + print("โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”") + deep_sup_status = "Enabled" if use_deep_supervision else "Disabled" + print( + "Use Deep Supervision:\n {} \n If enabled, activates deep " + "supervision in model.".format(deep_sup_status) + ) + print("โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”") + print( + "Project Name:\n '{}' \n Name identifier for the current" + " training session.".format(project_name) + ) + print("โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”") + print( + "Sub Name:\n '{}' \n Additional sub-identifier for organizing" + " outputs.".format(sub_name) + ) + print("โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”") + surf_dice_status = "Enabled" if use_surf_dice else "Disabled" + print( + "Use Surface Dice:\n {} \n If enabled, includes Surface-Dice in the loss " + "calculation.".format(surf_dice_status) + ) + print("โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”") + print( + "Surface Dice Weight:\n {} \n Weighting of the Surface-Dice" + " loss, if enabled.".format(surf_dice_weight) + ) + print("โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”") + if surf_dice_tokens: + tokens = ", ".join(surf_dice_tokens) + print( + "Surface Dice Tokens:\n [{}] \n Specific tokens used for " + "Surface-Dice loss. Other tokens will be neglected.".format(tokens) + ) + else: + print( + "Surface Dice Tokens:\n None \n No specific tokens are used for " + "Surface-Dice loss." + ) + print("\n") diff --git a/tests/membrain_seg/training/test_optim_utils.py b/tests/membrain_seg/training/test_optim_utils.py index 4151d7a..dd0433c 100644 --- a/tests/membrain_seg/training/test_optim_utils.py +++ b/tests/membrain_seg/training/test_optim_utils.py @@ -12,6 +12,7 @@ def test_loss_fn_correctness(): import torch from membrain_seg.segmentation.training.optim_utils import ( + CombinedLoss, DeepSuperVisionLoss, IgnoreLabelDiceCELoss, ) @@ -73,17 +74,24 @@ def extend_labels(labels): pred_labels[2][pred_labels[2] < 0.0] = 0.0 ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="mean") + combined_loss = CombinedLoss( + losses=[ignore_dice_loss], weights=[1.0], loss_inclusion_tokens=["ds1"] + ) losses = test_ignore_dice_loss(ignore_dice_loss, pred_labels, gt_labels) assert losses[0] == losses[1] == losses[2] == losses[4] != losses[3] deep_supervision_loss = DeepSuperVisionLoss( - ignore_dice_loss, weights=[1.0, 0.5, 0.25, 0.125, 0.0675] + combined_loss, weights=[1.0, 0.5, 0.25, 0.125, 0.0675] ) gt_labels_ds = extend_labels(gt_labels) ds_losses = [] for pred_label in pred_labels: pred_labels_ds = extend_labels(pred_label) - ds_losses.append(deep_supervision_loss(pred_labels_ds, gt_labels_ds)) + ds_losses.append( + deep_supervision_loss( + pred_labels_ds, gt_labels_ds, ["ds1"] * len(gt_labels_ds) + ) + ) assert ds_losses[0] == ds_losses[1] == ds_losses[2] == ds_losses[4] != ds_losses[3] print("All ignore loss assertions passed.")