Surface dice loss #45

Merged (23 commits) on Jan 22, 2024
Commits
a45d037
Surface-Dice functionalities
LorenzLamm Dec 31, 2023
df4a376
Adjust losses to be compatible with Surface-Dice exclusions
LorenzLamm Dec 31, 2023
79859e0
Adjust training routine and include surface dice loss
LorenzLamm Dec 31, 2023
4868e98
Add dataset labels to dataloading
LorenzLamm Dec 31, 2023
7bf110b
Pass Surface-Dice arguments to training routine
LorenzLamm Dec 31, 2023
860aa1c
Update CLI to include advanced options for Surface-Dice
LorenzLamm Dec 31, 2023
bd543f5
precommit formatting
LorenzLamm Dec 31, 2023
a3c0ce2
make list readable by passing argument multiple times
LorenzLamm Jan 3, 2024
9660790
remove redundant import
LorenzLamm Jan 3, 2024
94a1c81
Compatibility with updated masked_surface_dice function
LorenzLamm Jan 3, 2024
722a5ed
Add training summary and remove wandb logging
LorenzLamm Jan 3, 2024
044a7d6
remove redundant print statements and include ds_labels in for-loop
LorenzLamm Jan 3, 2024
84c2b0b
Implement Gaussian smoothing with torch to compute everything on GPU
LorenzLamm Jan 3, 2024
9b3ae9b
Training summary printing
LorenzLamm Jan 3, 2024
44bd6a1
add dataset token to CLI
LorenzLamm Jan 3, 2024
b577c5e
Add dataset token to filename
LorenzLamm Jan 3, 2024
5373977
Update warnings
LorenzLamm Jan 3, 2024
b4c9e43
Fix bug for accuracy masking
LorenzLamm Jan 3, 2024
6d1b0e5
Fix Dice reduction to scalar
LorenzLamm Jan 3, 2024
bc8acf0
Make test compatible with CombinedLoss
LorenzLamm Jan 3, 2024
9f39fc7
Fix default path
LorenzLamm Jan 3, 2024
0547611
Raise Error when reduction is not defined
LorenzLamm Jan 22, 2024
346d850
Add required dimensions to docstrings
LorenzLamm Jan 22, 2024
6 changes: 6 additions & 0 deletions src/membrain_seg/annotations/extract_patch_cli.py
@@ -21,6 +21,11 @@ def extract_patches(
help="Path to the folder where extracted patches should be stored. \
(subdirectories will be created)",
),
Review comment (Collaborator, Author):
Dataset token is now readable as well. This helps to distinguish between different datasets, because we may want to apply different loss functions (particularly Surface-Dice) to some datasets, but not to others.
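As a rough illustration only (hypothetical function and argument names, not the loss implementation added in this PR), gating the Surface-Dice term per dataset could look like this:

import torch

def combined_loss_sketch(
    pred, target, ds_tokens, surf_dice_tokens,
    base_loss, surf_dice_loss, surf_dice_weight=1.0,
):
    # Sketch: add the Surface-Dice term only for samples whose dataset token
    # is listed in surf_dice_tokens; other samples keep the base loss alone.
    loss = base_loss(pred, target)
    if surf_dice_tokens is not None:
        # Boolean mask over the batch: True where this sample's dataset uses Surface-Dice
        keep = torch.tensor(
            [tok in surf_dice_tokens for tok in ds_tokens],
            dtype=torch.bool,
            device=pred.device,
        )
        if keep.any():
            loss = loss + surf_dice_weight * surf_dice_loss(pred[keep], target[keep])
    return loss

Here ds_tokens would come from the per-sample "dataset" entries added to the dataloading below, and surf_dice_tokens from the new CLI option.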

ds_token: str = Option( # noqa: B008
"other",
help="Dataset token. Important for distinguishing between different \
datasets. Should NOT contain underscores!",
),
coords_file: str = Option( # noqa: B008
None,
help="Path to a file containing coordinates for patch extraction. The file \
@@ -93,6 +98,7 @@ def extract_patches(
coords=coords,
out_dir=out_folder,
idx_add=idx_add,
ds_token=ds_token,
token=token,
pad_value=pad_value,
)
31 changes: 21 additions & 10 deletions src/membrain_seg/annotations/extract_patches.py
@@ -51,7 +51,7 @@ def pad_labels(patch, padding, pad_value=2.0):


def get_out_files_and_patch_number(
token, out_folder_raw, out_folder_lab, patch_nr, idx_add
ds_token, token, out_folder_raw, out_folder_lab, patch_nr, idx_add
):
"""
Create filenames and corrected patch numbers.
@@ -62,8 +62,10 @@

Parameters
----------
ds_token : str
The dataset identifier used as a part of the filename.
token : str
The unique identifier used as a part of the filename.
The tomogram identifier used as a part of the filename.
out_folder_raw : str
The directory path where raw data patches are stored.
out_folder_lab : str
@@ -96,27 +98,34 @@
"""
patch_nr += idx_add
out_file_patch = os.path.join(
out_folder_raw, token + "_patch" + str(patch_nr) + "_raw.nii.gz"
out_folder_raw, ds_token + "_" + token + "_patch" + str(patch_nr) + ".nii.gz"
)
out_file_patch_label = os.path.join(
out_folder_lab, token + "_patch" + str(patch_nr) + "_labels.nii.gz"
out_folder_lab, ds_token + "_" + token + "_patch" + str(patch_nr) + ".nii.gz"
)
exist_add = 0
while os.path.isfile(out_file_patch):
exist_add += 1
out_file_patch = os.path.join(
out_folder_raw,
token + "_patch" + str(patch_nr + exist_add) + "_raw.nii.gz",
ds_token + "_" + token + "_patch" + str(patch_nr + exist_add) + ".nii.gz",
)
out_file_patch_label = os.path.join(
out_folder_lab,
token + "_patch" + str(patch_nr + exist_add) + "_labels.nii.gz",
ds_token + "_" + token + "_patch" + str(patch_nr + exist_add) + ".nii.gz",
)
return patch_nr + exist_add, out_file_patch, out_file_patch_label


def extract_patches(
tomo_path, seg_path, coords, out_dir, idx_add=0, token=None, pad_value=2.0
tomo_path,
seg_path,
coords,
out_dir,
ds_token="other",
token=None,
idx_add=0,
pad_value=2.0,
):
"""
Extracts 3D patches from a given tomogram and corresponding segmentation.
@@ -133,11 +142,13 @@
List of tuples where each tuple represents the 3D coordinates of a patch center.
out_dir : str
The output directory where the extracted patches will be saved.
idx_add : int, optional
The index addition for patch numbering, default is 0.
ds_token : str, optional
Dataset token to uniquely identify the dataset, default is 'other'.
token : str, optional
Token to uniquely identify the tomogram, default is None. If None,
the base name of the tomogram file path is used.
idx_add : int, optional
The index addition for patch numbering, default is 0.
pad_value: float, optional
Borders of extracted patch are padded with this value ("ignore" label)

@@ -170,7 +181,7 @@

for patch_nr, cur_coords in enumerate(coords):
patch_nr, out_file_patch, out_file_patch_label = get_out_files_and_patch_number(
token, out_folder_raw, out_folder_lab, patch_nr, idx_add
ds_token, token, out_folder_raw, out_folder_lab, patch_nr, idx_add
)
print("Extracting patch nr", patch_nr, "from tomo", token)
try:
8 changes: 5 additions & 3 deletions src/membrain_seg/annotations/merge_corrections.py
@@ -46,13 +46,15 @@ def get_corrections_from_folder(folder_name, orig_pred_file):
or filename.startswith("Ignore")
or filename.startswith("ignore")
):
print("ATTENTION! Not processing", filename)
print("Is this intended?")
print(
"File does not fit into Add/Remove/Ignore naming! " "Not processing",
filename,
)
continue
readdata = sitk.GetArrayFromImage(
sitk.ReadImage(os.path.join(folder_name, filename))
)
print("Adding file", filename, "<--")
print("Adding file", filename)

if filename.startswith("Add") or filename.startswith("add"):
add_patch += readdata
30 changes: 29 additions & 1 deletion src/membrain_seg/segmentation/cli/train_cli.py
@@ -1,4 +1,7 @@
from typing import List, Optional

from typer import Option
from typing_extensions import Annotated

from ..train import train as _train
from .cli import OPTION_PROMPT_KWARGS as PKWARGS
@@ -70,7 +73,7 @@ def train_advanced(
help="Batch size for training.",
),
num_workers: int = Option( # noqa: B008
1,
8,
help="Number of worker threads for loading data",
),
max_epochs: int = Option( # noqa: B008
@@ -84,6 +87,22 @@
but also severely increases training time.\
Pass "True" or "False".',
),
Review comment (Collaborator, Author):
surface_dice_tokens: dataset tokens specifying which datasets to apply surface-dice to.

Needs to be passed as separate arguments:
--surface-dice-tokens ds1 --surface-dice-tokens ds2

I did not find a more elegant way with Typer to pass in a list of strings

Reply (Collaborator):
Ah, that's a bummer, but I think it's okay. In the future, it might make sense to add support for a glob string or a directory so that people don't have to write out all of the tokens.
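For reference, a minimal stand-alone Typer sketch (hypothetical script, not part of this PR) showing how a repeated option is collected into a list:

from typing import List, Optional

import typer
from typing_extensions import Annotated

app = typer.Typer()

@app.command()
def train(
    surface_dice_tokens: Annotated[
        Optional[List[str]],
        typer.Option(help="Dataset tokens to apply Surface-Dice to."),
    ] = None,
):
    # Typer gathers every repeated --surface-dice-tokens flag into one list (None if omitted).
    typer.echo(f"Surface-Dice tokens: {surface_dice_tokens or []}")

if __name__ == "__main__":
    app()

Invoked as python train.py --surface-dice-tokens ds1 --surface-dice-tokens ds2, this prints ['ds1', 'ds2'].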

use_surface_dice: bool = Option( # noqa: B008
False, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".'
),
surface_dice_weight: float = Option( # noqa: B008
1.0, help="Scaling factor for the Surface-Dice loss. "
),
surface_dice_tokens: Annotated[
Optional[List[str]],
Option(
help='List of tokens to \
use for the Surface-Dice loss. \
Pass tokens separately:\
For example, train_advanced --surface_dice_tokens "ds1" \
--surface_dice_tokens "ds2"'
),
] = None,
use_deep_supervision: bool = Option( # noqa: B008
True, help='Whether to use deep supervision. Pass "True" or "False".'
),
@@ -119,6 +138,12 @@
If set to False, data augmentation still happens, but not as frequently.
More data augmentation can lead to a better performance, but also increases the
training time substantially.
use_surface_dice : bool
Determines whether to use the Surface-Dice loss, by default False.
surface_dice_weight : float
Scaling factor for the Surface-Dice loss, by default 1.0.
surface_dice_tokens : list
List of tokens to use for the Surface-Dice loss, by default ["all"].
use_deep_supervision : bool
Determines whether to use deep supervision, by default True.
project_name : str
@@ -140,6 +165,9 @@
max_epochs=max_epochs,
aug_prob_to_one=aug_prob_to_one,
use_deep_supervision=use_deep_supervision,
use_surf_dice=use_surface_dice,
surf_dice_weight=surface_dice_weight,
surf_dice_tokens=surface_dice_tokens,
project_name=project_name,
sub_name=sub_name,
)
24 changes: 23 additions & 1 deletion src/membrain_seg/segmentation/dataloading/memseg_dataset.py
@@ -1,7 +1,6 @@
import os
from typing import Dict

# from skimage import io
import imageio as io
import numpy as np
from torch.utils.data import Dataset
@@ -102,6 +101,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
"label": np.expand_dims(self.labels[idx], 0),
}
idx_dict = self.transforms(idx_dict)
Review comment (Collaborator, Author):
Dataset token is now returned with every training image.

idx_dict["dataset"] = self.dataset_labels[idx]
return idx_dict

def __len__(self) -> int:
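For orientation, a single item returned by __getitem__ would then look roughly like this (assuming dataset is an instance of this class; values are illustrative):

sample = dataset[0]
sorted(sample.keys())   # ['dataset', 'image', 'label']
sample["dataset"]       # e.g. "ds1", parsed from the patch filename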
@@ -126,6 +126,7 @@ def load_data(self) -> None:
print("Loading images into dataset.")
self.imgs = []
self.labels = []
self.dataset_labels = []
for entry in self.data_paths:
label = read_nifti(
entry[1]
Expand All @@ -137,6 +138,7 @@ def load_data(self) -> None:
img = np.transpose(img, (1, 2, 0))
self.imgs.append(img)
self.labels.append(label)
self.dataset_labels.append(get_dataset_token(entry[0]))

def initialize_imgs_paths(self) -> None:
"""
@@ -190,3 +192,23 @@ def test(self, test_folder: str, num_files: int = 20) -> None:
os.path.join(test_folder, f"test_mask_ds2_{i}_group{num_mask}.png"),
test_sample["label"][1][0, :, :, num_mask],
)


Review comment (Collaborator, Author):
Dataset token is defined as the part of the filename before the first underscore.

def get_dataset_token(patch_name):
"""
Get the dataset token from the patch name.

Parameters
----------
patch_name : str
The name of the patch.

Returns
-------
str
The dataset token.

"""
basename = os.path.basename(patch_name)
dataset_token = basename.split("_")[0]
return dataset_token
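A quick round-trip under the naming scheme introduced in extract_patches.py (hypothetical filenames):

# Patches are written as "<ds_token>_<token>_patch<nr>.nii.gz", so:
get_dataset_token("ds1_tomo17_patch3.nii.gz")                   # -> "ds1"
get_dataset_token("/data/patches/other_tomoA_patch0.nii.gz")    # -> "other"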