From 67417a444e21197c59579e8339b9d01d10a15d4f Mon Sep 17 00:00:00 2001
From: LorenzLamm <34575029+LorenzLamm@users.noreply.github.com>
Date: Wed, 1 Nov 2023 09:02:50 +0100
Subject: [PATCH 1/8] Fix lightning issue (#41)
* Fix model loading issue with new lightning versions
* Adjust code style
* Adjust to precommit style
* add download link for MONAI v1.3.0 model
* add troubleshooting section
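For context, a minimal sketch of the corrected loading pattern introduced by this patch (the checkpoint path is a placeholder):

```python
import torch

from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load_from_checkpoint is a classmethod in recent pytorch-lightning versions,
# so it is called on the class rather than on a freshly constructed instance.
# strict=False tolerates checkpoint keys that the current model definition
# no longer declares.
pl_model = SemanticSegmentationUnet.load_from_checkpoint(
    "path/to/checkpoint.ckpt",  # placeholder path
    map_location=device,
    strict=False,
)
pl_model.to(device)
pl_model.eval()
```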
---
docs/installation.md | 16 ++++++++++++++++
src/membrain_seg/segmentation/segment.py | 5 +++--
2 files changed, 19 insertions(+), 2 deletions(-)
diff --git a/docs/installation.md b/docs/installation.md
index d2f6a9f..cd47231 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -52,6 +52,9 @@ This should display the different options you can choose from MemBrain, like "se
## Step 5: Download pre-trained segmentation model (optional)
We recommend using denoised (ideally Cryo-CARE [1]) tomograms for segmentation. However, our current best model is available for download [here](https://drive.google.com/file/d/1tSQIz_UCsQZNfyHg0RxD-4meFgolszo8/view?usp=sharing) and should also work on non-denoised data. Please let us know how it works for you.
+
+NOTE: Previous model files are not compatible with MONAI v1.3.0 or higher. If you are using MONAI v1.3.0 or higher, either downgrade to MONAI v1.2.0 or download this [adapted version](https://drive.google.com/file/d/1Tfg2Ju-cgSj_71_b1gVMnjqNYea7L1Hm/view?usp=sharing) of our most recent model file.
+
If the given model does not work properly, you may want to try one of our previous versions:
Other (older) model versions:
@@ -65,3 +68,16 @@ Once downloaded, you can use it in MemBrain-seg's [Segmentation](./Usage/Segment
```
[1] T. -O. Buchholz, M. Jordan, G. Pigino and F. Jug, "Cryo-CARE: Content-Aware Image Restoration for Cryo-Transmission Electron Microscopy Data," 2019 IEEE 16th International Symposium on Biomedical Imaging (ISBI 2019), Venice, Italy, 2019, pp. 502-506, doi: 10.1109/ISBI.2019.8759519.
```
+
+
+# Troubleshooting
+Here is a collection of common issues and how to fix them:
+
+- `RuntimeError: The NVIDIA driver on your system is too old (found version 11070). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has
+been compiled with your version of the CUDA driver.`
+
+ The latest PyTorch versions require newer CUDA versions, which may not be installed on your system yet. You can either install a newer CUDA version or (maybe easier) downgrade PyTorch to a compatible version:
+
+ `pip uninstall torch`
+
+ `pip install torch==2.0.1`
\ No newline at end of file
diff --git a/src/membrain_seg/segmentation/segment.py b/src/membrain_seg/segmentation/segment.py
index e44a327..f5ccda0 100644
--- a/src/membrain_seg/segmentation/segment.py
+++ b/src/membrain_seg/segmentation/segment.py
@@ -78,8 +78,9 @@ def segment(
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize the model and load trained weights from checkpoint
- pl_model = SemanticSegmentationUnet()
- pl_model = pl_model.load_from_checkpoint(model_checkpoint, map_location=device)
+ pl_model = SemanticSegmentationUnet.load_from_checkpoint(
+ model_checkpoint, map_location=device, strict=False
+ )
pl_model.to(device)
# Preprocess the new data
From d76e97e2c51d920b8a3a932f3aa2862ce681481f Mon Sep 17 00:00:00 2001
From: LorenzLamm <34575029+LorenzLamm@users.noreply.github.com>
Date: Wed, 1 Nov 2023 09:51:15 +0100
Subject: [PATCH 2/8] Add default arguments to components and threshold
functions (#42)
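For context, a minimal sketch of the typer pattern applied below (the command body is made up): passing `...` (Ellipsis) as the first argument to `Option` marks the option as required instead of leaving it without a usable default.

```python
from typer import Option, Typer

app = Typer()


@app.command()
def components(
    segmentation_path: str = Option(
        ..., help="Path to the membrane segmentation to be processed."
    ),
    out_folder: str = Option(
        "./predictions", help="Folder where segmentations should be stored."
    ),
):
    # segmentation_path has no default, so typer aborts with an error message
    # if --segmentation-path is omitted; out_folder falls back to "./predictions".
    print(segmentation_path, out_folder)
```

With this, `--help` lists the path option as required rather than silently defaulting to None.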
---
src/membrain_seg/segmentation/cli/segment_cli.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/membrain_seg/segmentation/cli/segment_cli.py b/src/membrain_seg/segmentation/cli/segment_cli.py
index 5125e8e..b96cbec 100644
--- a/src/membrain_seg/segmentation/cli/segment_cli.py
+++ b/src/membrain_seg/segmentation/cli/segment_cli.py
@@ -78,7 +78,7 @@ def segment(
@cli.command(name="components", no_args_is_help=True)
def components(
segmentation_path: str = Option( # noqa: B008
- help="Path to the membrane segmentation to be processed.", **PKWARGS
+ ..., help="Path to the membrane segmentation to be processed.", **PKWARGS
),
out_folder: str = Option( # noqa: B008
"./predictions", help="Path to the folder where segmentations should be stored."
@@ -114,7 +114,7 @@ def components(
@cli.command(name="thresholds", no_args_is_help=True)
def thresholds(
scoremap_path: str = Option( # noqa: B008
- help="Path to the membrane scoremap to be processed.", **PKWARGS
+ ..., help="Path to the membrane scoremap to be processed.", **PKWARGS
),
thresholds: List[float] = Option( # noqa: B008
...,
From 1f9747ad2e6cdc2c19d8dfb12bdd8fa67ca01bb4 Mon Sep 17 00:00:00 2001
From: Ricardo Righetto
Date: Thu, 14 Dec 2023 16:34:52 +0100
Subject: [PATCH 3/8] Fixed output filename formatting which assumed the
extension was always 3 characters long (#43)
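For illustration (file names are made up): slicing off the last four characters only works when the extension is exactly three characters long, whereas `os.path.splitext` strips any extension.

```python
import os

for name in ["tomo_17.mrc", "tomo_17.em"]:
    old_stem = os.path.basename(name)[:-4]  # assumes a ".xyz"-style extension
    new_stem = os.path.splitext(os.path.basename(name))[0]  # extension-agnostic
    print(f"{name}: old={old_stem!r}, new={new_stem!r}")

# "tomo_17.em" shows the bug: the old slicing yields 'tomo_1',
# while splitext correctly yields 'tomo_17'.
```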
---
src/membrain_seg/segmentation/dataloading/data_utils.py | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/src/membrain_seg/segmentation/dataloading/data_utils.py b/src/membrain_seg/segmentation/dataloading/data_utils.py
index 6662708..b306f9d 100644
--- a/src/membrain_seg/segmentation/dataloading/data_utils.py
+++ b/src/membrain_seg/segmentation/dataloading/data_utils.py
@@ -175,7 +175,8 @@ def store_segmented_tomograms(
out_folder = out_folder
if store_probabilities:
out_file = os.path.join(
- out_folder, os.path.basename(orig_data_path)[:-4] + "_scores.mrc"
+ out_folder,
+ os.path.splitext(os.path.basename(orig_data_path))[0] + "_scores.mrc",
)
out_tomo = Tomogram(
data=predictions_np, header=mrc_header, voxel_size=voxel_size
@@ -186,7 +187,10 @@ def store_segmented_tomograms(
)
out_file_thres = os.path.join(
out_folder,
- os.path.basename(orig_data_path)[:-4] + "_" + ckpt_token + "_segmented.mrc",
+ os.path.splitext(os.path.basename(orig_data_path))[0]
+ + "_"
+ + ckpt_token
+ + "_segmented.mrc",
)
if store_connected_components:
predictions_np_thres = connected_components(
From 40bc5e7ccffbe9cd18d98597ab3fe4fa07fa9527 Mon Sep 17 00:00:00 2001
From: LorenzLamm <34575029+LorenzLamm@users.noreply.github.com>
Date: Wed, 17 Jan 2024 11:32:07 +0100
Subject: [PATCH 4/8] update ci for deployment (#46)
Co-authored-by: Lorenz Lamm
---
.github/ISSUE_TEMPLATE.md | 15 ++++++++
.github/TEST_FAIL_TEMPLATE.md | 12 +++++++
.github/workflows/ci.yml | 67 +++++++++++++++++++----------------
3 files changed, 64 insertions(+), 30 deletions(-)
create mode 100644 .github/ISSUE_TEMPLATE.md
create mode 100644 .github/TEST_FAIL_TEMPLATE.md
diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md
new file mode 100644
index 0000000..791b9c9
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE.md
@@ -0,0 +1,15 @@
+* membrain-seg version:
+* Python version:
+* Operating System:
+
+### Description
+
+Describe what you were trying to get done.
+Tell us what happened, what went wrong, and what you expected to happen.
+
+### What I Did
+
+```
+Paste the command(s) you ran and the output.
+If there was a crash, please include the traceback here.
+```
diff --git a/.github/TEST_FAIL_TEMPLATE.md b/.github/TEST_FAIL_TEMPLATE.md
new file mode 100644
index 0000000..3512972
--- /dev/null
+++ b/.github/TEST_FAIL_TEMPLATE.md
@@ -0,0 +1,12 @@
+---
+title: "{{ env.TITLE }}"
+labels: [bug]
+---
+The {{ workflow }} workflow failed on {{ date | date("YYYY-MM-DD HH:mm") }} UTC
+
+The most recent failing test was on {{ env.PLATFORM }} py{{ env.PYTHON }}
+with commit: {{ sha }}
+
+Full run: https://github.com/{{ repo }}/actions/runs/{{ env.RUN_ID }}
+
+(This post will be updated if another test fails, as long as this issue remains open.)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 701d160..986e54f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -9,13 +9,21 @@ on:
pull_request:
workflow_dispatch:
schedule:
- - cron: "0 0 * * 0" # every week (for --pre release tests)
+ # run every week (for --pre release tests)
+ - cron: "0 0 * * 0"
+
+# cancel in-progress runs that use the same workflow and branch
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
jobs:
check-manifest:
+ # check-manifest is a tool that checks that all files in version control are
+ # included in the sdist (unless explicitly excluded)
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- run: pipx run check-manifest
test:
@@ -24,35 +32,30 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: ['3.8', '3.9', '3.10']
- platform: [ubuntu-latest, macos-latest, windows-latest]
+ python-version: ["3.9", "3.10", "3.11", "3.12"]
+ platform: [ubuntu-latest] #, macos-latest, windows-latest]
steps:
- - name: Cancel Previous Runs
- uses: styfle/cancel-workflow-action@0.11.0
- with:
- access_token: ${{ github.token }}
-
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- - name: Set up Python ${{ matrix.python-version }}
+ - name: 🐍 Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache-dependency-path: "pyproject.toml"
cache: "pip"
- # if running a cron job, we add the --pre flag to test against pre-releases
- - name: Install dependencies
+ - name: Install Dependencies
run: |
python -m pip install -U pip
- python -m pip install -e .[test] ${{ github.event_name == 'schedule' && '--pre' || '' }}
+ # if running a cron job, we add the --pre flag to test against pre-releases
+ python -m pip install .[test] ${{ github.event_name == 'schedule' && '--pre' || '' }}
- - name: Test
+ - name: 🧪 Run Tests
run: pytest --color=yes --cov --cov-report=xml --cov-report=term-missing
- # If something goes wrong, we can open an issue in the repo
- - name: Report --pre Failures
+ # If something goes wrong with --pre tests, we can open an issue in the repo
+ - name: 📝 Report --pre Failures
if: failure() && github.event_name == 'schedule'
uses: JasonEtco/create-an-issue@v2
env:
@@ -60,7 +63,7 @@ jobs:
PLATFORM: ${{ matrix.platform }}
PYTHON: ${{ matrix.python-version }}
RUN_ID: ${{ github.run_id }}
- TITLE: '[test-bot] pip install --pre is failing'
+ TITLE: "[test-bot] pip install --pre is failing"
with:
filename: .github/TEST_FAIL_TEMPLATE.md
update_existing: true
@@ -74,28 +77,32 @@ jobs:
if: success() && startsWith(github.ref, 'refs/tags/') && github.event_name != 'schedule'
runs-on: ubuntu-latest
+ permissions:
+ # IMPORTANT: this permission is mandatory for trusted publishing on PyPi
+ # see https://docs.pypi.org/trusted-publishers/
+ id-token: write
+ # This permission allows writing releases
+ contents: write
+
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
- - name: Set up Python
+ - name: 🐍 Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.x"
- - name: install
+ - name: 👷 Build
run: |
- git tag
- pip install -U pip build twine
+ python -m pip install build
python -m build
- twine check dist/*
- ls -lh dist
- - name: Build and publish
- run: twine upload dist/*
- env:
- TWINE_USERNAME: __token__
- TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }}
+ - name: 🚢 Publish to PyPI
+ uses: pypa/gh-action-pypi-publish@release/v1
- uses: softprops/action-gh-release@v1
with:
generate_release_notes: true
+ files: './dist/*'
From 11cd2055fc99315fdbd26527b52c64461b530aa7 Mon Sep 17 00:00:00 2001
From: LorenzLamm <34575029+LorenzLamm@users.noreply.github.com>
Date: Wed, 17 Jan 2024 11:43:03 +0100
Subject: [PATCH 5/8] remove 3.12 from test matrix (#47)
Co-authored-by: Lorenz Lamm
---
.github/workflows/ci.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 986e54f..7078ab4 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -32,7 +32,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: ["3.9", "3.10", "3.11", "3.12"]
+ python-version: ["3.9", "3.10", "3.11"]
platform: [ubuntu-latest] #, macos-latest, windows-latest]
steps:
From 3fd4de900ad398b88a86d2cc83ffc7cb887cafd6 Mon Sep 17 00:00:00 2001
From: LorenzLamm <34575029+LorenzLamm@users.noreply.github.com>
Date: Wed, 17 Jan 2024 12:02:46 +0100
Subject: [PATCH 6/8] Add twine config to release section (#48)
Co-authored-by: Lorenz Lamm
---
.github/workflows/ci.yml | 2 ++
1 file changed, 2 insertions(+)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 7078ab4..c13bf4f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -101,6 +101,8 @@ jobs:
- name: 🚢 Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
+ with:
+ password: ${{ secrets.TWINE_API_KEY }}
- uses: softprops/action-gh-release@v1
with:
From 49aa7980d224d5c59490bfcff48a1b2ce3706efc Mon Sep 17 00:00:00 2001
From: LorenzLamm <34575029+LorenzLamm@users.noreply.github.com>
Date: Mon, 22 Jan 2024 18:45:57 +0100
Subject: [PATCH 7/8] Surface dice loss (#45)
* Surface-Dice functionalities
* Adjust losses to be compatible with Surface-Dice exclusions
* Adjust training routine and include surface dice loss
* Add dataset labels to dataloading
* Pass Surface-Dice arguments to training routine
* Update CLI to include advanced options for Surface-Dice
* precommit formatting
* make list readable by passing argument multiple times
* remove redundant import
* Compatibility with updated masked_surface_dice function
* Add training summary and remove wandb logging
* remove redundant print statements and include ds_labels in for-loop
* Implement Gaussian smoothing with torch to compute everything on GPU
* Training summary printing
* add dataset token to CLI
* Add dataset token to filename
* Update warnings
* Fix bug for accuracy masking
* Fix Dice reduction to scalar
* Make test compatible with CombinedLoss
* Fix default path
* Raise Error when reduction is not defined
* Add required dimensions to docstrings
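For orientation, the core of the Surface-Dice computation added here follows the clDice idea: compare the soft skeleton of the prediction against the ground truth segmentation and vice versa, then combine the two overlap ratios. A simplified sketch without the ignore-label masking and sigmoid handling of the full masked_surface_dice implementation in this patch:

```python
import torch


def surface_dice_sketch(
    pred: torch.Tensor,  # predicted probabilities, shape (B, 1, D, H, W)
    target: torch.Tensor,  # binary ground truth, shape (B, 1, D, H, W)
    skel_pred: torch.Tensor,  # soft skeleton of pred
    skel_true: torch.Tensor,  # soft skeleton of target
    smooth: float = 1.0,
) -> torch.Tensor:
    # topology precision: fraction of the predicted skeleton lying inside the GT
    tprec = (torch.sum(skel_pred * target, dim=(1, 2, 3, 4)) + smooth) / (
        torch.sum(skel_pred, dim=(1, 2, 3, 4)) + smooth
    )
    # topology sensitivity: fraction of the GT skeleton covered by the prediction
    tsens = (torch.sum(skel_true * pred, dim=(1, 2, 3, 4)) + smooth) / (
        torch.sum(skel_true, dim=(1, 2, 3, 4)) + smooth
    )
    # harmonic combination, one score per batch element
    return 2.0 * tprec * tsens / (tprec + tsens)
```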
---
.../annotations/extract_patch_cli.py | 6 +
.../annotations/extract_patches.py | 31 +-
.../annotations/merge_corrections.py | 8 +-
.../segmentation/cli/train_cli.py | 30 +-
.../dataloading/memseg_dataset.py | 24 +-
.../segmentation/networks/unet.py | 107 +++-
src/membrain_seg/segmentation/train.py | 37 +-
.../segmentation/training/metric_utils.py | 2 +-
.../segmentation/training/optim_utils.py | 124 ++++-
.../segmentation/training/surface_dice.py | 458 ++++++++++++++++++
.../training/training_param_summary.py | 117 +++++
.../membrain_seg/training/test_optim_utils.py | 12 +-
12 files changed, 902 insertions(+), 54 deletions(-)
create mode 100644 src/membrain_seg/segmentation/training/surface_dice.py
create mode 100644 src/membrain_seg/segmentation/training/training_param_summary.py
diff --git a/src/membrain_seg/annotations/extract_patch_cli.py b/src/membrain_seg/annotations/extract_patch_cli.py
index 339e476..a5074f2 100644
--- a/src/membrain_seg/annotations/extract_patch_cli.py
+++ b/src/membrain_seg/annotations/extract_patch_cli.py
@@ -21,6 +21,11 @@ def extract_patches(
help="Path to the folder where extracted patches should be stored. \
(subdirectories will be created)",
),
+ ds_token: str = Option( # noqa: B008
+ "other",
+ help="Dataset token. Important for distinguishing between different \
+ datasets. Should NOT contain underscores!",
+ ),
coords_file: str = Option( # noqa: B008
None,
help="Path to a file containing coordinates for patch extraction. The file \
@@ -93,6 +98,7 @@ def extract_patches(
coords=coords,
out_dir=out_folder,
idx_add=idx_add,
+ ds_token=ds_token,
token=token,
pad_value=pad_value,
)
diff --git a/src/membrain_seg/annotations/extract_patches.py b/src/membrain_seg/annotations/extract_patches.py
index d9628ab..7e5ff78 100644
--- a/src/membrain_seg/annotations/extract_patches.py
+++ b/src/membrain_seg/annotations/extract_patches.py
@@ -51,7 +51,7 @@ def pad_labels(patch, padding, pad_value=2.0):
def get_out_files_and_patch_number(
- token, out_folder_raw, out_folder_lab, patch_nr, idx_add
+ ds_token, token, out_folder_raw, out_folder_lab, patch_nr, idx_add
):
"""
Create filenames and corrected patch numbers.
@@ -62,8 +62,10 @@ def get_out_files_and_patch_number(
Parameters
----------
+ ds_token : str
+ The dataset identifier used as a part of the filename.
token : str
- The unique identifier used as a part of the filename.
+ The tomogram identifier used as a part of the filename.
out_folder_raw : str
The directory path where raw data patches are stored.
out_folder_lab : str
@@ -96,27 +98,34 @@ def get_out_files_and_patch_number(
"""
patch_nr += idx_add
out_file_patch = os.path.join(
- out_folder_raw, token + "_patch" + str(patch_nr) + "_raw.nii.gz"
+ out_folder_raw, ds_token + "_" + token + "_patch" + str(patch_nr) + ".nii.gz"
)
out_file_patch_label = os.path.join(
- out_folder_lab, token + "_patch" + str(patch_nr) + "_labels.nii.gz"
+ out_folder_lab, ds_token + "_" + token + "_patch" + str(patch_nr) + ".nii.gz"
)
exist_add = 0
while os.path.isfile(out_file_patch):
exist_add += 1
out_file_patch = os.path.join(
out_folder_raw,
- token + "_patch" + str(patch_nr + exist_add) + "_raw.nii.gz",
+ ds_token + "_" + token + "_patch" + str(patch_nr + exist_add) + ".nii.gz",
)
out_file_patch_label = os.path.join(
out_folder_lab,
- token + "_patch" + str(patch_nr + exist_add) + "_labels.nii.gz",
+ ds_token + "_" + token + "_patch" + str(patch_nr + exist_add) + ".nii.gz",
)
return patch_nr + exist_add, out_file_patch, out_file_patch_label
def extract_patches(
- tomo_path, seg_path, coords, out_dir, idx_add=0, token=None, pad_value=2.0
+ tomo_path,
+ seg_path,
+ coords,
+ out_dir,
+ ds_token="other",
+ token=None,
+ idx_add=0,
+ pad_value=2.0,
):
"""
Extracts 3D patches from a given tomogram and corresponding segmentation.
@@ -133,11 +142,13 @@ def extract_patches(
List of tuples where each tuple represents the 3D coordinates of a patch center.
out_dir : str
The output directory where the extracted patches will be saved.
- idx_add : int, optional
- The index addition for patch numbering, default is 0.
+ ds_token : str, optional
+ Dataset token to uniquely identify the dataset, default is 'other'.
token : str, optional
Token to uniquely identify the tomogram, default is None. If None,
the base name of the tomogram file path is used.
+ idx_add : int, optional
+ The index addition for patch numbering, default is 0.
pad_value: float, optional
Borders of extracted patch are padded with this value ("ignore" label)
@@ -170,7 +181,7 @@ def extract_patches(
for patch_nr, cur_coords in enumerate(coords):
patch_nr, out_file_patch, out_file_patch_label = get_out_files_and_patch_number(
- token, out_folder_raw, out_folder_lab, patch_nr, idx_add
+ ds_token, token, out_folder_raw, out_folder_lab, patch_nr, idx_add
)
print("Extracting patch nr", patch_nr, "from tomo", token)
try:
diff --git a/src/membrain_seg/annotations/merge_corrections.py b/src/membrain_seg/annotations/merge_corrections.py
index 361c64a..e2c0d68 100644
--- a/src/membrain_seg/annotations/merge_corrections.py
+++ b/src/membrain_seg/annotations/merge_corrections.py
@@ -46,13 +46,15 @@ def get_corrections_from_folder(folder_name, orig_pred_file):
or filename.startswith("Ignore")
or filename.startswith("ignore")
):
- print("ATTENTION! Not processing", filename)
- print("Is this intended?")
+ print(
+ "File does not fit into Add/Remove/Ignore naming! " "Not processing",
+ filename,
+ )
continue
readdata = sitk.GetArrayFromImage(
sitk.ReadImage(os.path.join(folder_name, filename))
)
- print("Adding file", filename, "<--")
+ print("Adding file", filename)
if filename.startswith("Add") or filename.startswith("add"):
add_patch += readdata
diff --git a/src/membrain_seg/segmentation/cli/train_cli.py b/src/membrain_seg/segmentation/cli/train_cli.py
index ac3d394..8cc6132 100644
--- a/src/membrain_seg/segmentation/cli/train_cli.py
+++ b/src/membrain_seg/segmentation/cli/train_cli.py
@@ -1,4 +1,7 @@
+from typing import List, Optional
+
from typer import Option
+from typing_extensions import Annotated
from ..train import train as _train
from .cli import OPTION_PROMPT_KWARGS as PKWARGS
@@ -70,7 +73,7 @@ def train_advanced(
help="Batch size for training.",
),
num_workers: int = Option( # noqa: B008
- 1,
+ 8,
help="Number of worker threads for loading data",
),
max_epochs: int = Option( # noqa: B008
@@ -84,6 +87,22 @@ def train_advanced(
but also severely increases training time.\
Pass "True" or "False".',
),
+ use_surface_dice: bool = Option( # noqa: B008
+ False, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".'
+ ),
+ surface_dice_weight: float = Option( # noqa: B008
+ 1.0, help="Scaling factor for the Surface-Dice loss. "
+ ),
+ surface_dice_tokens: Annotated[
+ Optional[List[str]],
+ Option(
+ help='List of tokens to \
+ use for the Surface-Dice loss. \
+ Pass tokens separately:\
+ For example, train_advanced --surface_dice_tokens "ds1" \
+ --surface_dice_tokens "ds2"'
+ ),
+ ] = None,
use_deep_supervision: bool = Option( # noqa: B008
True, help='Whether to use deep supervision. Pass "True" or "False".'
),
@@ -119,6 +138,12 @@ def train_advanced(
If set to False, data augmentation still happens, but not as frequently.
More data augmentation can lead to a better performance, but also increases the
training time substantially.
+ use_surface_dice : bool
+ Determines whether to use Surface-Dice loss, by default False.
+ surface_dice_weight : float
+ Scaling factor for the Surface-Dice loss, by default 1.0.
+ surface_dice_tokens : list
+ List of tokens to use for the Surface-Dice loss, by default ["all"].
use_deep_supervision : bool
Determines whether to use deep supervision, by default True.
project_name : str
@@ -140,6 +165,9 @@ def train_advanced(
max_epochs=max_epochs,
aug_prob_to_one=aug_prob_to_one,
use_deep_supervision=use_deep_supervision,
+ use_surf_dice=use_surface_dice,
+ surf_dice_weight=surface_dice_weight,
+ surf_dice_tokens=surface_dice_tokens,
project_name=project_name,
sub_name=sub_name,
)
diff --git a/src/membrain_seg/segmentation/dataloading/memseg_dataset.py b/src/membrain_seg/segmentation/dataloading/memseg_dataset.py
index d856883..9b66dda 100644
--- a/src/membrain_seg/segmentation/dataloading/memseg_dataset.py
+++ b/src/membrain_seg/segmentation/dataloading/memseg_dataset.py
@@ -1,7 +1,6 @@
import os
from typing import Dict
-# from skimage import io
import imageio as io
import numpy as np
from torch.utils.data import Dataset
@@ -102,6 +101,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
"label": np.expand_dims(self.labels[idx], 0),
}
idx_dict = self.transforms(idx_dict)
+ idx_dict["dataset"] = self.dataset_labels[idx]
return idx_dict
def __len__(self) -> int:
@@ -126,6 +126,7 @@ def load_data(self) -> None:
print("Loading images into dataset.")
self.imgs = []
self.labels = []
+ self.dataset_labels = []
for entry in self.data_paths:
label = read_nifti(
entry[1]
@@ -137,6 +138,7 @@ def load_data(self) -> None:
img = np.transpose(img, (1, 2, 0))
self.imgs.append(img)
self.labels.append(label)
+ self.dataset_labels.append(get_dataset_token(entry[0]))
def initialize_imgs_paths(self) -> None:
"""
@@ -190,3 +192,23 @@ def test(self, test_folder: str, num_files: int = 20) -> None:
os.path.join(test_folder, f"test_mask_ds2_{i}_group{num_mask}.png"),
test_sample["label"][1][0, :, :, num_mask],
)
+
+
+def get_dataset_token(patch_name):
+ """
+ Get the dataset token from the patch name.
+
+ Parameters
+ ----------
+ patch_name : str
+ The name of the patch.
+
+ Returns
+ -------
+ str
+ The dataset token.
+
+ """
+ basename = os.path.basename(patch_name)
+ dataset_token = basename.split("_")[0]
+ return dataset_token
diff --git a/src/membrain_seg/segmentation/networks/unet.py b/src/membrain_seg/segmentation/networks/unet.py
index 7a67310..711f1d1 100644
--- a/src/membrain_seg/segmentation/networks/unet.py
+++ b/src/membrain_seg/segmentation/networks/unet.py
@@ -8,17 +8,13 @@
from monai.transforms import AsDiscrete, Compose, EnsureType, Lambda
from ..training.metric_utils import masked_accuracy, threshold_function
-
-# from monai.networks.nets import UNet as MonaiUnet
-# The normal Monai DynUNet upsamples low-resolution layers to compare directly to GT
-# My implementation leaves them in low resolution and compares to down-sampled GT
-# Not sure which implementation is better
-# To be discussed with Alister & Kevin
from ..training.optim_utils import (
+ CombinedLoss,
DeepSuperVisionLoss,
- DynUNetDirectDeepSupervision, # I like to use deep supervision
+ DynUNetDirectDeepSupervision,
IgnoreLabelDiceCELoss,
)
+from ..training.surface_dice import IgnoreLabelSurfaceDiceLoss, masked_surface_dice
class SemanticSegmentationUnet(pl.LightningModule):
@@ -62,6 +58,12 @@ class SemanticSegmentationUnet(pl.LightningModule):
The maximum number of epochs for training.
use_deep_supervision : bool, default=False
Whether to use deep supervision.
+ use_surf_dice : bool, default=False
+ Whether to use surface dice loss.
+ surf_dice_weight : float, default=1.0
+ The weight for the surface dice loss.
+ surf_dice_tokens : list, default=None
+ The tokens for which to compute the surface dice loss.
"""
@@ -80,6 +82,9 @@ def __init__(
roi_size: Tuple[int, ...] = (160, 160, 160),
max_epochs: int = 1000,
use_deep_supervision: bool = False,
+ use_surf_dice: bool = False,
+ surf_dice_weight: float = 1.0,
+ surf_dice_tokens: list = None,
):
super().__init__()
@@ -102,16 +107,39 @@ def __init__(
upsample_kernel_size=(1, 2, 2, 2, 2, 2),
filters=channels,
res_block=True,
- # norm_name="INSTANCE",
- # norm=Norm.INSTANCE, # I like the instance normalization better than
- # batchnorm in this case, as we will probably have
- # only small batch sizes, making BN more noisy
deep_supervision=True,
deep_supr_num=2,
)
- ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="mean")
+
+ ### Build up loss function
+ losses = []
+ weights = []
+ loss_inclusion_tokens = []
+ ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="none")
+ losses.append(ignore_dice_loss)
+ weights.append(1.0)
+ loss_inclusion_tokens.append(["all"]) # Apply to every element
+
+ if use_surf_dice:
+ if surf_dice_tokens is None:
+ surf_dice_tokens = ["all"]
+ ignore_surf_dice_loss = IgnoreLabelSurfaceDiceLoss(
+ ignore_label=2, soft_skel_iterations=5
+ )
+ losses.append(ignore_surf_dice_loss)
+ weights.append(surf_dice_weight)
+ loss_inclusion_tokens.append(surf_dice_tokens)
+
+ scaled_weights = [entry / sum(weights) for entry in weights]
+
+ loss_function = CombinedLoss(
+ losses=losses,
+ weights=scaled_weights,
+ loss_inclusion_tokens=loss_inclusion_tokens,
+ )
+
self.loss_function = DeepSuperVisionLoss(
- ignore_dice_loss,
+ loss_function,
weights=[1.0, 0.5, 0.25, 0.125, 0.0675]
if use_deep_supervision
else [1.0, 0.0, 0.0, 0.0, 0.0],
@@ -143,7 +171,9 @@ def __init__(
self.training_step_outputs = []
self.validation_step_outputs = []
self.running_train_acc = 0.0
+ self.running_train_surf_dice = 0.0
self.running_val_acc = 0.0
+ self.running_val_surf_dice = 0.0
def forward(self, x) -> torch.Tensor:
"""Implementation of the forward pass.
@@ -180,9 +210,9 @@ def training_step(
See the pytorch-lightning module documentation for details.
"""
- images, labels = batch["image"], batch["label"]
+ images, labels, ds_label = batch["image"], batch["label"], batch["dataset"]
output = self.forward(images)
- loss = self.loss_function(output, labels)
+ loss = self.loss_function(output, labels, ds_label)
stats_dict = {"train_loss": loss, "train_number": output[0].shape[0]}
self.training_step_outputs.append(stats_dict)
@@ -190,6 +220,17 @@ def training_step(
masked_accuracy(output[0], labels[0], ignore_label=2.0, threshold_value=0.0)
* output[0].shape[0]
)
+ self.running_train_surf_dice += (
+ masked_surface_dice(
+ data=output[0].detach(),
+ target=labels[0].detach(),
+ ignore_label=2.0,
+ soft_skel_iterations=5,
+ smooth=1.0,
+ reduction="mean",
+ )
+ * output[0].shape[0]
+ )
return {"loss": loss}
@@ -207,13 +248,17 @@ def on_train_epoch_end(self):
mean_train_loss = torch.tensor(train_loss / num_items)
mean_train_acc = self.running_train_acc / num_items
+ mean_train_surf_dice = self.running_train_surf_dice / num_items
self.running_train_acc = 0.0
- self.log("train_loss", mean_train_loss) # , batch_size=num_items)
- self.log("train_acc", mean_train_acc) # , batch_size=num_items)
+ self.running_train_surf_dice = 0.0
+ self.log("train_loss", mean_train_loss)
+ self.log("train_acc", mean_train_acc)
+ self.log("train_surf_dice", mean_train_surf_dice)
self.training_step_outputs = []
print("EPOCH Training loss", mean_train_loss.item())
print("EPOCH Training acc", mean_train_acc.item())
+ print("EPOCH Training surface dice", mean_train_surf_dice.item())
# Accuracy not the most informative metric, but a good sanity check
return {"train_loss": mean_train_loss}
@@ -224,13 +269,9 @@ def validation_step(self, batch, batch_idx):
using a sliding window. See the pytorch-lightning
module documentation for details.
"""
- images, labels = batch[self.image_key], batch[self.label_key]
- # sw_batch_size = 4
- # outputs = sliding_window_inference(
- # images, self.roi_size, sw_batch_size, self.forward
- # )
+ images, labels, ds_label = batch["image"], batch["label"], batch["dataset"]
outputs = self.forward(images)
- loss = self.loss_function(outputs, labels)
+ loss = self.loss_function(outputs, labels, ds_label)
# Cloning and adjusting preds & labels for Dice.
# Could also use the same labels, but maybe we want to
@@ -254,6 +295,18 @@ def validation_step(self, batch, batch_idx):
)
* outputs[0].shape[0]
)
+
+ self.running_val_surf_dice += (
+ masked_surface_dice(
+ data=outputs[0].detach(),
+ target=labels[0].detach(),
+ ignore_label=2.0,
+ soft_skel_iterations=5,
+ smooth=1.0,
+ reduction="mean",
+ )
+ * outputs[0].shape[0]
+ )
return stats_dict
def on_validation_epoch_end(self):
@@ -270,13 +323,17 @@ def on_validation_epoch_end(self):
mean_val_loss = torch.tensor(val_loss / num_items)
mean_val_acc = self.running_val_acc / num_items
+ mean_val_surf_dice = self.running_val_surf_dice / num_items
self.running_val_acc = 0.0
- self.log("val_loss", mean_val_loss), # batch_size=num_items)
- self.log("val_dice", mean_val_dice) # , batch_size=num_items)
+ self.running_val_surf_dice = 0.0
+ self.log("val_loss", mean_val_loss),
+ self.log("val_dice", mean_val_dice)
+ self.log("val_surf_dice", mean_val_surf_dice)
self.log("val_accuracy", mean_val_acc)
self.validation_step_outputs = []
print("EPOCH Validation loss", mean_val_loss.item())
print("EPOCH Validation dice", mean_val_dice)
+ print("EPOCH Validation surface dice", mean_val_surf_dice.item())
print("EPOCH Validation acc", mean_val_acc.item())
return {"val_loss": mean_val_loss, "val_metric": mean_val_dice}
diff --git a/src/membrain_seg/segmentation/train.py b/src/membrain_seg/segmentation/train.py
index 2f077f2..9c576f5 100644
--- a/src/membrain_seg/segmentation/train.py
+++ b/src/membrain_seg/segmentation/train.py
@@ -9,6 +9,9 @@
MemBrainSegDataModule,
)
from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet
+from membrain_seg.segmentation.training.training_param_summary import (
+ print_training_parameters,
+)
warnings.filterwarnings("ignore", category=UserWarning, module="torch._tensor")
warnings.filterwarnings("ignore", category=UserWarning, module="monai.data")
@@ -24,6 +27,9 @@ def train(
use_deep_supervision: bool = False,
project_name: str = "membrain-seg_v0",
sub_name: str = "1",
+ use_surf_dice: bool = False,
+ surf_dice_weight: float = 1.0,
+ surf_dice_tokens: list = None,
):
"""
Train the model on the specified data.
@@ -52,11 +58,31 @@ def train(
Name of the project for logging purposes.
sub_name : str, optional
Sub-name of the project for logging purposes.
+ use_surf_dice : bool, optional
+ If True, enables Surface-Dice loss.
+ surf_dice_weight : float, optional
+ Weight for the Surface-Dice loss.
+ surf_dice_tokens : list, optional
+ List of tokens to use for the Surface-Dice loss.
Returns
-------
None
"""
+ print_training_parameters(
+ data_dir=data_dir,
+ log_dir=log_dir,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ max_epochs=max_epochs,
+ aug_prob_to_one=aug_prob_to_one,
+ use_deep_supervision=use_deep_supervision,
+ project_name=project_name,
+ sub_name=sub_name,
+ use_surf_dice=use_surf_dice,
+ surf_dice_weight=surf_dice_weight,
+ surf_dice_tokens=surf_dice_tokens,
+ )
# Set up the data module
data_module = MemBrainSegDataModule(
data_dir=data_dir,
@@ -67,15 +93,16 @@ def train(
# Set up the model
model = SemanticSegmentationUnet(
- max_epochs=max_epochs, use_deep_supervision=use_deep_supervision
+ max_epochs=max_epochs,
+ use_deep_supervision=use_deep_supervision,
+ use_surf_dice=use_surf_dice,
+ surf_dice_weight=surf_dice_weight,
+ surf_dice_tokens=surf_dice_tokens,
)
project_name = project_name
checkpointing_name = project_name + "_" + sub_name
# Set up logging
- wandb_logger = pl_loggers.WandbLogger(
- project=project_name, log_model=False, save_code=True
- )
csv_logger = pl_loggers.CSVLogger(log_dir)
# Set up model checkpointing
@@ -106,7 +133,7 @@ def on_epoch_start(self, trainer, pl_module):
# Set up the trainer
trainer = pl.Trainer(
precision="16-mixed",
- logger=[csv_logger, wandb_logger],
+ logger=[csv_logger],
callbacks=[
checkpoint_callback_val_loss,
checkpoint_callback_regular,
diff --git a/src/membrain_seg/segmentation/training/metric_utils.py b/src/membrain_seg/segmentation/training/metric_utils.py
index f30a258..ad8f300 100644
--- a/src/membrain_seg/segmentation/training/metric_utils.py
+++ b/src/membrain_seg/segmentation/training/metric_utils.py
@@ -34,7 +34,7 @@ def masked_accuracy(
mask = (
y_gt == ignore_label
if ignore_label is not None
- else torch.ones_like(y_gt).bool()
+ else torch.zeros_like(y_gt).bool()
)
acc = (threshold_function(y_pred, threshold_value=threshold_value) == y_gt).float()
acc[mask] = 0.0
diff --git a/src/membrain_seg/segmentation/training/optim_utils.py b/src/membrain_seg/segmentation/training/optim_utils.py
index 8dbdc6b..a16e136 100644
--- a/src/membrain_seg/segmentation/training/optim_utils.py
+++ b/src/membrain_seg/segmentation/training/optim_utils.py
@@ -53,7 +53,7 @@ class IgnoreLabelDiceCELoss(_Loss):
def __init__(
self,
ignore_label: int,
- reduction: str = "mean",
+ reduction: str = "none",
lambda_dice: float = 1.0,
lambda_ce: float = 1.0,
**kwargs,
@@ -95,11 +95,31 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
orig_data, target_tensor, reduction="none"
)
bce_loss[~mask] = 0.0
- bce_loss = torch.sum(bce_loss) / torch.sum(mask)
- dice_loss = self.dice_loss(data, target, mask)
+ # TODO: Check if this is correct: I adjusted the loss to be
+ # computed per batch element
+ bce_loss = torch.sum(bce_loss, dim=(1, 2, 3, 4)) / torch.sum(
+ mask, dim=(1, 2, 3, 4)
+ )
+ # Compute Dice loss separately for each batch element
+ dice_loss = torch.zeros_like(bce_loss)
+ for batch_idx in range(data.shape[0]):
+ dice_loss[batch_idx] = self.dice_loss(
+ data[batch_idx].unsqueeze(0),
+ target[batch_idx].unsqueeze(0),
+ mask[batch_idx].unsqueeze(0),
+ )
# Combine the Dice and Cross Entropy losses
combined_loss = self.lambda_dice * dice_loss + self.lambda_ce * bce_loss
+ if self.reduction == "mean":
+ combined_loss = combined_loss.mean()
+ elif self.reduction == "sum":
+ combined_loss = combined_loss.sum()
+ elif self.reduction == "none":
+ pass  # keep per-element losses; CombinedLoss aggregates them itself
+ else:
+ raise ValueError(
+ f"Invalid reduction type {self.reduction}. "
+ "Valid options are 'mean', 'sum' and 'none'."
+ )
return combined_loss
@@ -134,7 +154,7 @@ def __init__(
self.loss_fn = loss_fn
self.weights = weights
- def forward(self, inputs: list, targets: list) -> torch.Tensor:
+ def forward(self, inputs: list, targets: list, ds_labels: list) -> torch.Tensor:
"""
Compute the loss.
@@ -144,6 +164,8 @@ def forward(self, inputs: list, targets: list) -> torch.Tensor:
List of tensors of model outputs.
targets : list
List of tensors of target labels.
+ ds_labels : list
+ List of dataset labels for each batch element.
Returns
-------
@@ -151,6 +173,96 @@ def forward(self, inputs: list, targets: list) -> torch.Tensor:
The calculated loss.
"""
loss = 0.0
- for weight, data, target in zip(self.weights, inputs, targets):
- loss += weight * self.loss_fn(data, target)
+ ds_labels_loop = [ds_labels] * 5
+ for weight, data, target, ds_label in zip(
+ self.weights, inputs, targets, ds_labels_loop
+ ):
+ loss += weight * self.loss_fn(data, target, ds_label)
+ return loss
+
+
+class CombinedLoss(_Loss):
+ """
+ Combine multiple loss functions into a single one.
+
+ Parameters
+ ----------
+ losses : List[Callable]
+ A list of loss function instances.
+ weights : List[float]
+ List of weights corresponding to each loss function (must
+ be of same length as losses).
+ loss_inclusion_tokens : List[List[str]]
+ A list of lists containing tokens for each loss function.
+ Each sublist corresponds to a loss function and contains
+ tokens for which the loss should be included.
+ If the list contains "all", then the loss will be included
+ for all cases.
+
+ Notes
+ -----
+ IMPORTANT: Loss functions need to return a tensor containing the
+ loss for each batch element.
+
+ The loss_inclusion_tokens parameter restricts each loss function to certain
+ cases. For example, if the loss_inclusion_tokens
+ parameter is [["ds1", "ds2"], ["ds1"]], then the first loss function
+ is only applied to batch elements whose dataset label is "ds1" or "ds2",
+ and the second loss function only to batch elements whose
+ dataset label is "ds1".
+ """
+
+ def __init__(
+ self,
+ losses: list,
+ weights: list,
+ loss_inclusion_tokens: list,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+ self.losses = losses
+ self.weights = weights
+ self.loss_inclusion_tokens = loss_inclusion_tokens
+
+ def forward(
+ self, data: torch.Tensor, target: torch.Tensor, ds_label: list
+ ) -> torch.Tensor:
+ """
+ Compute the combined loss.
+
+ Parameters
+ ----------
+ data : torch.Tensor
+ Tensor of model outputs.
+ target : torch.Tensor
+ Tensor of target labels.
+ ds_label : List[str]
+ List of dataset labels for each batch element.
+
+ Returns
+ -------
+ torch.Tensor
+ The calculated combined loss.
+ """
+ loss = 0.0
+ for loss_idx, (cur_loss, cur_weight) in enumerate(
+ zip(self.losses, self.weights)
+ ):
+ cur_loss_val = cur_loss(data, target)
+
+ # Zero out losses for excluded cases
+ for batch_idx, ds_lab in enumerate(ds_label):
+ if (
+ "all" in self.loss_inclusion_tokens[loss_idx]
+ or ds_lab in self.loss_inclusion_tokens[loss_idx]
+ ):
+ continue
+ cur_loss_val[batch_idx] = 0.0
+
+ # Aggregate loss
+ cur_loss_val = cur_loss_val.sum() / ((cur_loss_val != 0.0).sum() + 1e-3)
+ loss += cur_weight * cur_loss_val
+
+ # Normalize loss
+ loss = loss / sum(self.weights)
return loss
diff --git a/src/membrain_seg/segmentation/training/surface_dice.py b/src/membrain_seg/segmentation/training/surface_dice.py
new file mode 100644
index 0000000..bc2c6b4
--- /dev/null
+++ b/src/membrain_seg/segmentation/training/surface_dice.py
@@ -0,0 +1,458 @@
+"""
+Surface Dice implementation.
+
+Adapted from: clDice - A Novel Topology-Preserving Loss Function for Tubular
+Structure Segmentation
+Original Authors: Johannes C. Paetzold and Suprosanna Shit
+Sources: https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/
+ soft_skeleton.py
+ https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py
+License: MIT License.
+
+The following code is a modification of the original clDice implementation.
+Modifications were made to include additional functionality and integrate
+with new project requirements. The original license and copyright notice are
+provided below.
+
+MIT License
+
+Copyright (c) 2021 Johannes C. Paetzold and Suprosanna Shit
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+
+import math
+
+import torch
+import torch.nn.functional as F
+from torch.nn.functional import sigmoid
+from torch.nn.modules.loss import _Loss
+
+
+def soft_erode(img: torch.Tensor, separate_pool: bool = False) -> torch.Tensor:
+ """
+ Apply soft erosion operation to the input image.
+
+ Soft erosion is achieved by applying a min-pooling operation to the input image.
+
+ Parameters
+ ----------
+ img : torch.Tensor
+ Input image tensor with shape (B, C, D, H, W)
+ separate_pool : bool, optional
+ If True, perform separate 3D max-pooling operations along different axes.
+ Default is False.
+
+ Returns
+ -------
+ torch.Tensor
+ Eroded image tensor with the same shape as the input.
+
+ Raises
+ ------
+ ValueError
+ If the input tensor has an unsupported number of dimensions.
+
+ Notes
+ -----
+ - The soft erosion can be performed with separate 3D min-pooling operations
+ along different axes if separate_pool is True, or with a single
+ 3D min-pooling operation with a kernel of size (3, 3, 3) if
+ separate_pool is False.
+ """
+ assert len(img.shape) == 5
+ if separate_pool:
+ p1 = -F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0))
+ p2 = -F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0))
+ p3 = -F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1))
+ return torch.min(torch.min(p1, p2), p3)
+ p4 = -F.max_pool3d(-img, (3, 3, 3), (1, 1, 1), (1, 1, 1))
+ return p4
+
+
+def soft_dilate(img: torch.Tensor) -> torch.Tensor:
+ """
+ Apply soft dilation operation to the input image.
+
+ Soft dilation is achieved by applying a max-pooling operation to the input image.
+
+ Parameters
+ ----------
+ img : torch.Tensor
+ Input image tensor with shape (B, C, D, H, W).
+
+ Returns
+ -------
+ torch.Tensor
+ Dilated image tensor with the same shape as the input.
+
+ Raises
+ ------
+ ValueError
+ If the input tensor has an unsupported number of dimensions.
+
+ Notes
+ -----
+ - For 5D input, the soft dilation is performed using a 3D max-pooling operation
+ with a kernel of size (3, 3, 3).
+ """
+ assert len(img.shape) == 5
+ return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), (1, 1, 1))
+
+
+def soft_open(img: torch.Tensor, separate_pool: bool = False) -> torch.Tensor:
+ """
+ Apply soft opening operation to the input image.
+
+ Soft opening is achieved by applying soft erosion followed by soft dilation.
+ The intention of soft opening is to remove thin membranes from the segmentation.
+
+ Parameters
+ ----------
+ img : torch.Tensor
+ Input image tensor with shape (B, C, D, H, W).
+ separate_pool : bool, optional
+ If True, perform separate erosion and dilation operations. Default is False.
+
+ Returns
+ -------
+ torch.Tensor
+ Opened image tensor with the same shape as the input.
+
+ Notes
+ -----
+ - Soft opening is performed by applying soft erosion followed by soft dilation
+ to the input image.
+ - For 5D input, separate erosion and dilation can be performed if separate_pool
+ is True.
+ """
+ return soft_dilate(soft_erode(img, separate_pool=separate_pool))
+
+
+def soft_skel(
+ img: torch.Tensor, iter_: int, separate_pool: bool = False
+) -> torch.Tensor:
+ """
+ Compute the soft skeleton of the input image.
+
+ The skeleton is computed by applying soft erosion iteratively to the input image.
+ In each iteration, the difference between the input image and the "opened" image is
+ computed and added to the skeleton.
+
+ Reasoning: if there is a difference between the input image and the "opened" image,
+ there must be a thin membrane skeleton in the input image that was removed by the
+ opening operation.
+
+ Parameters
+ ----------
+ img : torch.Tensor
+ Input image tensor with shape (B, C, D, H, W).
+ iter_ : int
+ Number of iterations for skeletonization.
+ separate_pool : bool, optional
+ If True, perform separate erosion and dilation operations.
+ Default is False.
+
+ Returns
+ -------
+ torch.Tensor
+ Soft skeleton image tensor with the same shape as the input.
+
+ Notes
+ -----
+ - Separate erosion can be performed if separate_pool is True.
+ """
+ img1 = soft_open(img, separate_pool=separate_pool)
+ skel = F.relu(img - img1)
+ for _j in range(iter_):
+ img = soft_erode(img)
+ img1 = soft_open(img, separate_pool=separate_pool)
+ delta = F.relu(img - img1)
+ skel = skel + F.relu(delta - skel * delta)
+ return skel
+
+
+def gaussian_kernel(size: int, sigma: float) -> torch.Tensor:
+ """
+ Creates a 3D Gaussian kernel using the specified size and sigma.
+
+ Parameters
+ ----------
+ size : int
+ The size of the Gaussian kernel. It determines the length of
+ each dimension of the cube.
+ sigma : float
+ The standard deviation of the Gaussian kernel. It controls
+ the spread of the Gaussian.
+
+ Returns
+ -------
+ torch.Tensor
+ A 3D tensor representing the Gaussian kernel.
+
+ Notes
+ -----
+ The function creates a Gaussian kernel, which is essentially a
+ cube of dimensions [size, size, size]. Each entry in the cube is
+ computed using the Gaussian function based on its distance from the center.
+ The kernel is normalized so that its total sum equals 1.
+ """
+ # Define a coordinate grid centered at (0,0,0)
+ grid = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
+ # Create a 3D meshgrid
+ x, y, z = torch.meshgrid(grid, grid, grid)
+ xyz_grid = torch.stack([x, y, z], dim=-1)
+
+ # Calculate the 3D Gaussian kernel
+ gaussian_kernel = torch.exp(-torch.sum(xyz_grid**2, dim=-1) / (2 * sigma**2))
+ gaussian_kernel /= (2 * math.pi * sigma**2) ** (3 / 2) # Normalize
+
+ # Ensure sum of values in gaussian kernel equals 1.
+ gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
+ return gaussian_kernel
+
+
+gaussian_kernel_dict = {}
+""" Not sure why, but moving the gaussian kernel to GPU takes surprisingly long.+
+So we precompute it, store it on GPU, and reuse it.
+"""
+
+
+def apply_gaussian_filter(
+ seg: torch.Tensor, kernel_size: int, sigma: float
+) -> torch.Tensor:
+ """
+ Apply a Gaussian filter to a segmentation tensor using PyTorch.
+
+ This function convolves the input tensor with a Gaussian kernel.
+ The function creates or retrieves a Gaussian kernel based on the
+ specified size and standard deviation, and applies 3D convolution to each
+ channel of each batch item with appropriate padding to maintain spatial
+ dimensions.
+
+ Parameters
+ ----------
+ seg : torch.Tensor
+ The input segmentation tensor of shape (batch, channel, X, Y, Z).
+ kernel_size : int
+ The size of the Gaussian kernel, determining the length of each
+ dimension of the cube.
+ sigma : float
+ The standard deviation of the Gaussian kernel, controlling the spread.
+
+ Returns
+ -------
+ torch.Tensor
+ The filtered segmentation tensor of the same shape as input.
+
+ Notes
+ -----
+ This function uses a precomputed dictionary to enhance performance by
+ storing Gaussian kernels. If a kernel with the specified size and standard
+ deviation does not exist in the dictionary, it is created and added. The
+ function assumes the input tensor is a 5D tensor, applies 3D convolution
+ using the Gaussian kernel with padding to maintain spatial dimensions, and
+ it performs the operation separately for each channel of each batch item.
+ """
+ # Create the Gaussian kernel or load it from the dictionary
+ if (kernel_size, sigma) not in gaussian_kernel_dict.keys():
+ gaussian_kernel_dict[(kernel_size, sigma)] = gaussian_kernel(
+ kernel_size, sigma
+ ).to(seg.device)
+ g_kernel = gaussian_kernel_dict[(kernel_size, sigma)]
+
+ # Add batch and channel dimensions
+ g_kernel = g_kernel.view(1, 1, *g_kernel.size())
+ # Apply the Gaussian filter to each channel
+ padding = kernel_size // 2
+
+ # Move the kernel to the same device as the segmentation tensor
+ g_kernel = g_kernel.to(seg.device)
+
+ # Apply the Gaussian filter
+ filtered_seg = F.conv3d(seg, g_kernel, padding=padding, groups=seg.shape[1])
+ return filtered_seg
+
+
+def get_GT_skeleton(gt_seg: torch.Tensor, iterations: int = 5) -> torch.Tensor:
+ """
+ Generate the skeleton of a ground truth segmentation.
+
+ This function takes a ground truth segmentation `gt_seg`, smooths it using a
+ Gaussian filter, and then computes its soft skeleton using the `soft_skel` function.
+
+ Intention: When using the binary ground truth segmentation for skeletonization,
+ the resulting skeleton is very patchy and not smooth. When using the smoothed
+ ground truth segmentation, the resulting skeleton is much smoother and more
+ accurate.
+
+ Parameters
+ ----------
+ gt_seg : torch.Tensor
+ A torch.Tensor representing the ground truth segmentation.
+ Shape: (B, C, D, H, W)
+ iterations : int, optional
+ The number of iterations for skeletonization. Default is 5.
+
+ Returns
+ -------
+ torch.Tensor
+ A torch.Tensor representing the skeleton of the ground truth segmentation.
+
+ Notes
+ -----
+ - The input `gt_seg` should be a binary segmentation tensor where 1 represents the
+ object of interest.
+ - The function first smooths the `gt_seg` using a Gaussian filter to enhance the
+ object's structure.
+ - The skeletonization process is performed using the `soft_skel` function with the
+ specified number of iterations.
+ - The resulting skeleton is returned as a binary torch.Tensor where 1 indicates the
+ skeleton points.
+ """
+ gt_smooth = (
+ apply_gaussian_filter((gt_seg == 1) * 1.0, kernel_size=15, sigma=2.0) * 1.5
+ )
+ skel_gt = soft_skel(gt_smooth, iter_=iterations)
+ return skel_gt
+
+
+def masked_surface_dice(
+ data: torch.Tensor,
+ target: torch.Tensor,
+ ignore_label: int = 2,
+ soft_skel_iterations: int = 3,
+ smooth: float = 3.0,
+ binary_prediction: bool = False,
+ reduction: str = "none",
+) -> torch.Tensor:
+ """
+ Compute the surface Dice loss with masking for ignore labels.
+
+ The surface Dice loss measures the similarity between the predicted segmentation's
+ skeleton and the ground truth segmentation (and vice versa). Labels annotated with
+ "ignore_label" are ignored.
+
+ Parameters
+ ----------
+ data : torch.Tensor
+ Tensor of model outputs representing the predicted segmentation.
+ Expected shape: (B, C, D, H, W)
+ target : torch.Tensor
+ Tensor of target labels representing the ground truth segmentation.
+ Expected shape: (B, 1, D, H, W)
+ ignore_label : int
+ The label value to be ignored when computing the loss.
+ soft_skel_iterations : int
+ Number of iterations for skeletonization in the underlying operations.
+ smooth : float
+ Smoothing factor to avoid division by zero.
+ binary_prediction : bool
+ If True, the predicted segmentation is assumed to be binary. Default is False.
+ reduction : str
+ Specifies the reduction to apply to the output. Default is "none".
+
+ Returns
+ -------
+ torch.Tensor
+ The calculated surface Dice loss.
+ """
+ # Create a mask to ignore the specified label in the target
+ data = sigmoid(data)
+ mask = target != ignore_label
+
+ # Compute soft skeletonization
+ if binary_prediction:
+ skel_pred = get_GT_skeleton(data.clone(), soft_skel_iterations)
+ else:
+ skel_pred = soft_skel(data.clone(), soft_skel_iterations, separate_pool=False)
+ skel_true = get_GT_skeleton(target.clone(), soft_skel_iterations)
+
+ # Mask out ignore labels
+ skel_pred[~mask] = 0
+ skel_true[~mask] = 0
+
+ # compute surface dice loss
+ tprec = (
+ torch.sum(torch.multiply(skel_pred, target), dim=(1, 2, 3, 4)) + smooth
+ ) / (torch.sum(skel_pred, dim=(1, 2, 3, 4)) + smooth)
+ tsens = (torch.sum(torch.multiply(skel_true, data), dim=(1, 2, 3, 4)) + smooth) / (
+ torch.sum(skel_true, dim=(1, 2, 3, 4)) + smooth
+ )
+ surf_dice_loss = 2.0 * (tprec * tsens) / (tprec + tsens)
+ if reduction == "none":
+ return surf_dice_loss
+ elif reduction == "mean":
+ return torch.mean(surf_dice_loss)
+
+
+class IgnoreLabelSurfaceDiceLoss(_Loss):
+ """
+ Surface Dice loss, adding ignore labels.
+
+ Parameters
+ ----------
+ ignore_label : int
+ The label to ignore when calculating the loss.
+ soft_skel_iterations : int, optional
+ Number of iterations for the soft skeletonization, by default 3.
+ smooth : float, optional
+ Smoothing factor to avoid division by zero, by default 3.0.
+ kwargs : dict
+ Additional keyword arguments.
+ """
+
+ def __init__(
+ self,
+ ignore_label: int,
+ soft_skel_iterations: int = 3,
+ smooth: float = 3.0,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+ self.ignore_label = ignore_label
+ self.soft_skel_iterations = soft_skel_iterations
+ self.smooth = smooth
+
+ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the loss.
+
+ Parameters
+ ----------
+ data : torch.Tensor
+ Tensor of model outputs.
+ Expected shape: (B, C, D, H, W)
+ target : torch.Tensor
+ Tensor of target labels.
+ Expected shape: (B, 1, D, H, W)
+
+ Returns
+ -------
+ torch.Tensor
+ The calculated loss.
+ """
+ # Create a mask to ignore the specified label in the target
+ surf_dice_score = masked_surface_dice(
+ data=data,
+ target=target,
+ ignore_label=self.ignore_label,
+ soft_skel_iterations=self.soft_skel_iterations,
+ smooth=self.smooth,
+ )
+ surf_dice_loss = 1.0 - surf_dice_score
+ return surf_dice_loss
diff --git a/src/membrain_seg/segmentation/training/training_param_summary.py b/src/membrain_seg/segmentation/training/training_param_summary.py
new file mode 100644
index 0000000..67277c8
--- /dev/null
+++ b/src/membrain_seg/segmentation/training/training_param_summary.py
@@ -0,0 +1,117 @@
+def print_training_parameters(
+ data_dir: str = "",
+ log_dir: str = "logs/",
+ batch_size: int = 2,
+ num_workers: int = 8,
+ max_epochs: int = 1000,
+ aug_prob_to_one: bool = False,
+ use_deep_supervision: bool = False,
+ project_name: str = "membrain-seg_v0",
+ sub_name: str = "1",
+ use_surf_dice: bool = False,
+ surf_dice_weight: float = 1.0,
+ surf_dice_tokens: list = None,
+):
+ """
+ Print a formatted overview of the training parameters with explanations.
+
+ Parameters
+ ----------
+ data_dir : str, optional
+ Path to the directory containing training data.
+ log_dir : str, optional
+ Path to the directory where logs should be stored.
+ batch_size : int, optional
+ Number of samples per batch of input data.
+ num_workers : int, optional
+ Number of subprocesses to use for data loading.
+ max_epochs : int, optional
+ Maximum number of epochs to train for.
+ aug_prob_to_one : bool, optional
+ If True, all augmentation probabilities are set to 1.
+ use_deep_supervision : bool, optional
+ If True, enables deep supervision in the U-Net model.
+ project_name : str, optional
+ Name of the project for logging purposes.
+ sub_name : str, optional
+ Sub-name of the project for logging purposes.
+ use_surf_dice : bool, optional
+ If True, enables Surface-Dice loss.
+ surf_dice_weight : float, optional
+ Weight for the Surface-Dice loss.
+ surf_dice_tokens : list, optional
+ List of tokens to use for the Surface-Dice loss.
+
+ Returns
+ -------
+ None
+ """
+ print("\033[1mTraining Parameters Overview:\033[0m\n")
+ print(
+ "Data Directory:\n '{}' \n Path to the directory containing "
+ "training data.".format(data_dir)
+ )
+ print("────────────────────────────────────────────────────────")
+ print(
+ "Log Directory:\n '{}' \n Directory where logs and outputs will "
+ "be stored.".format(log_dir)
+ )
+ print("────────────────────────────────────────────────────────")
+ print(
+ "Batch Size:\n {} \n Number of samples processed in a single batch.".format(
+ batch_size
+ )
+ )
+ print("────────────────────────────────────────────────────────")
+ print(
+ "Number of Workers:\n {} \n Subprocesses to use for data "
+ "loading.".format(num_workers)
+ )
+ print("────────────────────────────────────────────────────────")
+ print(f"Max Epochs:\n {max_epochs} \n Maximum number of training epochs.")
+ print("────────────────────────────────────────────────────────")
+ aug_status = "Enabled" if aug_prob_to_one else "Disabled"
+ print(
+ "Augmentation Probability to One:\n {} \n If enabled, sets all "
+ "augmentation probabilities to 1. (strong augmentation)".format(aug_status)
+ )
+ print("────────────────────────────────────────────────────────")
+ deep_sup_status = "Enabled" if use_deep_supervision else "Disabled"
+ print(
+ "Use Deep Supervision:\n {} \n If enabled, activates deep "
+ "supervision in model.".format(deep_sup_status)
+ )
+ print("────────────────────────────────────────────────────────")
+ print(
+ "Project Name:\n '{}' \n Name identifier for the current"
+ " training session.".format(project_name)
+ )
+ print("────────────────────────────────────────────────────────")
+ print(
+ "Sub Name:\n '{}' \n Additional sub-identifier for organizing"
+ " outputs.".format(sub_name)
+ )
+ print("────────────────────────────────────────────────────────")
+ surf_dice_status = "Enabled" if use_surf_dice else "Disabled"
+ print(
+ "Use Surface Dice:\n {} \n If enabled, includes Surface-Dice in the loss "
+ "calculation.".format(surf_dice_status)
+ )
+ print("────────────────────────────────────────────────────────")
+ print(
+ "Surface Dice Weight:\n {} \n Weighting of the Surface-Dice"
+ " loss, if enabled.".format(surf_dice_weight)
+ )
+ print("────────────────────────────────────────────────────────")
+ if surf_dice_tokens:
+ tokens = ", ".join(surf_dice_tokens)
+ print(
+ "Surface Dice Tokens:\n [{}] \n Specific tokens used for "
+ "Surface-Dice loss. Other tokens will be neglected.".format(tokens)
+ )
+ else:
+ print(
+ "Surface Dice Tokens:\n None \n No specific tokens are used for "
+ "Surface-Dice loss."
+ )
+ print("\n")
diff --git a/tests/membrain_seg/training/test_optim_utils.py b/tests/membrain_seg/training/test_optim_utils.py
index 4151d7a..dd0433c 100644
--- a/tests/membrain_seg/training/test_optim_utils.py
+++ b/tests/membrain_seg/training/test_optim_utils.py
@@ -12,6 +12,7 @@ def test_loss_fn_correctness():
import torch
from membrain_seg.segmentation.training.optim_utils import (
+ CombinedLoss,
DeepSuperVisionLoss,
IgnoreLabelDiceCELoss,
)
@@ -73,17 +74,24 @@ def extend_labels(labels):
pred_labels[2][pred_labels[2] < 0.0] = 0.0
ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="mean")
+ combined_loss = CombinedLoss(
+ losses=[ignore_dice_loss], weights=[1.0], loss_inclusion_tokens=["ds1"]
+ )
losses = test_ignore_dice_loss(ignore_dice_loss, pred_labels, gt_labels)
assert losses[0] == losses[1] == losses[2] == losses[4] != losses[3]
deep_supervision_loss = DeepSuperVisionLoss(
- ignore_dice_loss, weights=[1.0, 0.5, 0.25, 0.125, 0.0675]
+ combined_loss, weights=[1.0, 0.5, 0.25, 0.125, 0.0675]
)
gt_labels_ds = extend_labels(gt_labels)
ds_losses = []
for pred_label in pred_labels:
pred_labels_ds = extend_labels(pred_label)
- ds_losses.append(deep_supervision_loss(pred_labels_ds, gt_labels_ds))
+ ds_losses.append(
+ deep_supervision_loss(
+ pred_labels_ds, gt_labels_ds, ["ds1"] * len(gt_labels_ds)
+ )
+ )
assert ds_losses[0] == ds_losses[1] == ds_losses[2] == ds_losses[4] != ds_losses[3]
print("All ignore loss assertions passed.")
From 560beb3637bf22972f9e5ef7dc4f571b82cca678 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Tue, 23 Jan 2024 07:49:56 +0100
Subject: [PATCH 8/8] ci(dependabot): bump actions/checkout from 3 to 4 (#34)
Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v3...v4)
---
updated-dependencies:
- dependency-name: actions/checkout
dependency-type: direct:production
update-type: version-update:semver-major
...
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Kevin Yamauchi
---
.github/workflows/build-and-deploy-docs.yml | 2 +-
.github/workflows/ci.yml | 4 +---
2 files changed, 2 insertions(+), 4 deletions(-)
diff --git a/.github/workflows/build-and-deploy-docs.yml b/.github/workflows/build-and-deploy-docs.yml
index 20892fe..2de9d12 100644
--- a/.github/workflows/build-and-deploy-docs.yml
+++ b/.github/workflows/build-and-deploy-docs.yml
@@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index c13bf4f..552caad 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -38,7 +38,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- - name: 🐍 Set up Python ${{ matrix.python-version }}
+ - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
@@ -86,8 +86,6 @@ jobs:
steps:
- uses: actions/checkout@v4
- with:
- fetch-depth: 0
- name: 🐍 Set up Python
uses: actions/setup-python@v4