Surface dice loss #45

Merged
merged 23 commits into main from SurfaceDice on Jan 22, 2024
Conversation

LorenzLamm
Collaborator

Added the option to use Surface-Dice as a loss function during training.

Surface-Dice is based on "clDice - a Novel Topology-Preserving Loss Function for Tubular Structure Segmentation"
(https://openaccess.thecvf.com/content/CVPR2021/papers/Shit_clDice_-_A_Novel_Topology-Preserving_Loss_Function_for_Tubular_Structure_CVPR_2021_paper.pdf)

Also fixed some issues in patch extraction (corrected naming), removed wandb tracking (it caused dependency issues), fixed the bug mentioned in #44, and added printing of a training parameter summary.

@@ -21,6 +21,11 @@ def extract_patches(
help="Path to the folder where extracted patches should be stored. \
(subdirectories will be created)",
),
Collaborator Author

The dataset token is now readable as well. This helps distinguish between different datasets, because we may want to apply different loss functions (particularly Surface-Dice) to some datasets but not to others.

@@ -84,6 +87,22 @@ def train_advanced(
but also severely increases training time.\
Pass "True" or "False".',
),
Collaborator Author

surface_dice_tokens: dataset tokens specifying which datasets to apply Surface-Dice to.

Needs to be passed as separate arguments:
--surface-dice-tokens ds1 --surface-dice-tokens ds2

I did not find a more elegant way with Typer to pass in a list of strings.
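For reference, a minimal sketch of how such a repeatable list option looks in Typer (the command and option names here are illustrative, not the actual membrain-seg CLI code):

```python
from typing import List, Optional

import typer

app = typer.Typer()


@app.command()
def train(
    surface_dice_tokens: Optional[List[str]] = typer.Option(
        None,
        "--surface-dice-tokens",
        help="Dataset token to apply Surface-Dice to; repeat the flag once per token.",
    ),
):
    # --surface-dice-tokens ds1 --surface-dice-tokens ds2  ->  ["ds1", "ds2"]
    typer.echo(f"Surface-Dice tokens: {surface_dice_tokens}")


if __name__ == "__main__":
    app()
```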

Collaborator

Ah, that's a bummer, but I think it's okay. In the future, it might make sense to add support for a glob string or a directory so that people don't have to write out all of the tokens.

@@ -102,6 +101,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
"label": np.expand_dims(self.labels[idx], 0),
}
idx_dict = self.transforms(idx_dict)
Collaborator Author

The dataset token is now returned with every training image.
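A hypothetical minimal version of the dataset class illustrating the idea (class, attribute, and field names are assumptions, not the real memseg_dataset code):

```python
from typing import Dict

import numpy as np


class MemSegDatasetSketch:
    """Hypothetical minimal stand-in for the real dataset class."""

    def __init__(self, imgs, labels, dataset_labels, transforms=lambda d: d):
        self.imgs, self.labels = imgs, labels
        self.dataset_labels = dataset_labels  # e.g. ["ds1", "ds2", ...]
        self.transforms = transforms

    def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
        idx_dict = {
            "image": np.expand_dims(self.imgs[idx], 0),
            "label": np.expand_dims(self.labels[idx], 0),
        }
        idx_dict = self.transforms(idx_dict)
        # New: the dataset token travels with every training sample
        idx_dict["dataset"] = self.dataset_labels[idx]
        return idx_dict
```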

@@ -190,3 +192,23 @@ def test(self, test_folder: str, num_files: int = 20) -> None:
os.path.join(test_folder, f"test_mask_ds2_{i}_group{num_mask}.png"),
test_sample["label"][1][0, :, :, num_mask],
)


Collaborator Author

The dataset token is defined as the first token before the first underscore.
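In code, that convention might look like this (the function name and example filename are illustrative):

```python
import os


def get_dataset_token(patch_name: str) -> str:
    """'ds1_tomo17_patch003.nii.gz' -> 'ds1': everything before the first underscore."""
    return os.path.basename(patch_name).split("_")[0]
```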

deep_supervision=True,
deep_supr_num=2,
)
ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="mean")

Collaborator Author

The following initializes a weighted average between loss functions.
BCE/Dice losses are used by default; Surface-Dice is optionally added.
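A sketch of that setup, pieced together from the snippets in this PR (the `use_surf_dice` flag and the `IgnoreLabelSurfaceDiceLoss` class name are assumptions; only `CombinedLoss`, `IgnoreLabelDiceCELoss`, and the two quoted lines come from the diff):

```python
# Per-element losses (reduction="none") so CombinedLoss can mask per dataset
losses = [IgnoreLabelDiceCELoss(ignore_label=2, reduction="none")]
weights = [1.0]
loss_inclusion_tokens = [["all"]]  # BCE/Dice is applied to every dataset

if use_surf_dice:  # hypothetical flag
    losses.append(IgnoreLabelSurfaceDiceLoss(ignore_label=2))  # hypothetical class
    weights.append(surf_dice_weight)
    loss_inclusion_tokens.append(surf_dice_tokens)

scaled_weights = [entry / sum(weights) for entry in weights]
loss_function = CombinedLoss(
    losses=losses,
    weights=scaled_weights,
    loss_inclusion_tokens=loss_inclusion_tokens,
)
```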

codecov-commenter commented Jan 3, 2024

Codecov Report

Attention: 125 lines in your changes are missing coverage. Please review.

Comparison is base (1f9747a) 5.41% compared to head (9f39fc7) 7.73%.

Files                                                  | Patch % | Lines
...membrain_seg/segmentation/training/surface_dice.py | 23.37%  | 59 Missing ⚠️
...eg/segmentation/training/training_param_summary.py | 0.00%   | 32 Missing ⚠️
src/membrain_seg/segmentation/networks/unet.py        | 5.26%   | 18 Missing ⚠️
...ain_seg/segmentation/dataloading/memseg_dataset.py | 0.00%   | 7 Missing ⚠️
.../membrain_seg/segmentation/training/optim_utils.py | 90.00%  | 3 Missing ⚠️
src/membrain_seg/annotations/merge_corrections.py     | 0.00%   | 2 Missing ⚠️
src/membrain_seg/segmentation/cli/train_cli.py        | 0.00%   | 2 Missing ⚠️
src/membrain_seg/segmentation/train.py                | 0.00%   | 2 Missing ⚠️


Additional details and impacted files
@@           Coverage Diff            @@
##            main     #45      +/-   ##
========================================
+ Coverage   5.41%   7.73%   +2.31%     
========================================
  Files         38      40       +2     
  Lines       1256    1410     +154     
========================================
+ Hits          68     109      +41     
- Misses      1188    1301     +113     


loss_inclusion_tokens.append(surf_dice_tokens)

scaled_weights = [entry / sum(weights) for entry in weights]

Collaborator Author

The combined loss function computes the losses only for selected datasets.
BCE/Dice are computed for all datasets; for Surface-Dice, the datasets can be custom-chosen.
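A self-contained sketch of that masking logic (simplified; the "all" sentinel and the function shape are assumptions, not the PR's exact CombinedLoss.forward):

```python
import torch


def combine_losses(per_loss_values, weights, loss_inclusion_tokens, ds_tokens):
    """Weighted sum of per-element losses, applied only to selected datasets.

    per_loss_values: one tensor of shape [B] per loss (computed with reduction="none")
    ds_tokens: the B dataset tokens of the batch
    """
    combined = torch.zeros_like(per_loss_values[0])
    for values, weight, tokens in zip(per_loss_values, weights, loss_inclusion_tokens):
        # Zero out batch elements whose dataset is not selected for this loss
        mask = torch.tensor(
            [1.0 if ("all" in tokens or tok in tokens) else 0.0 for tok in ds_tokens],
            dtype=values.dtype,
        )
        combined = combined + weight * values * mask
    return combined.mean()


# Example: BCE/Dice on all datasets, Surface-Dice only on "ds2"
dice_vals = torch.tensor([0.4, 0.6])
surf_vals = torch.tensor([0.3, 0.2])
loss = combine_losses(
    [dice_vals, surf_vals],
    weights=[0.8, 0.2],
    loss_inclusion_tokens=[["all"], ["ds2"]],
    ds_tokens=["ds1", "ds2"],
)
print(loss)  # the Surface-Dice term contributes only for the ds2 element
```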

output = self.forward(images)
- loss = self.loss_function(output, labels)
+ loss = self.loss_function(output, labels, ds_label)

stats_dict = {"train_loss": loss, "train_number": output[0].shape[0]}
self.training_step_outputs.append(stats_dict)
self.running_train_acc += (
masked_accuracy(output[0], labels[0], ignore_label=2.0, threshold_value=0.0)
* output[0].shape[0]
)
Collaborator Author

also log surface-dice during training

surf_dice_weight : float, optional
Weight for the Surface-Dice loss.
surf_dice_tokens : list, optional
List of tokens to use for the Surface-Dice loss.

Returns
-------
None
"""
Collaborator Author

This prints a summary of training parameters before each training run.

@@ -106,7 +133,7 @@ def on_epoch_start(self, trainer, pl_module):
# Set up the trainer
trainer = pl.Trainer(
precision="16-mixed",
Collaborator Author

wandb logging requires an additional dependency and wandb registration. It should become an option in the future, but is removed for now.

@@ -34,7 +34,7 @@ def masked_accuracy(
mask = (
y_gt == ignore_label
if ignore_label is not None
- else torch.ones_like(y_gt).bool()
+ else torch.zeros_like(y_gt).bool()
Collaborator Author

Fixed issue mentioned in #44

@@ -53,7 +53,7 @@ class IgnoreLabelDiceCELoss(_Loss):
def __init__(
self,
ignore_label: int,
Collaborator Author

Loss functions should by default return non-reduced losses, i.e., a single value for each element in the batch.
This way, we can decide for each batch element whether to apply the respective loss (depending on the dataset token).
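A toy example of that convention (a made-up MSE-style loss, not the PR's IgnoreLabelDiceCELoss):

```python
import torch
from torch.nn.modules.loss import _Loss


class PerElementLoss(_Loss):
    """Toy loss illustrating the reduction convention: shape [B] by default."""

    def __init__(self, reduction: str = "none"):
        super().__init__(reduction=reduction)

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Reduce over everything except the batch dimension
        per_element = ((pred - target) ** 2).flatten(1).mean(dim=1)  # shape [B]
        if self.reduction == "mean":
            return per_element.mean()
        if self.reduction == "sum":
            return per_element.sum()
        return per_element  # "none": caller masks/weights per batch element
```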

Collaborator

Agreed!

return loss


class CombinedLoss(_Loss):
Collaborator Author

The combined loss function computes the weighted average of all losses and considers only the specified datasets.
This way, we can choose exactly which losses to apply to which dataset.

@@ -0,0 +1,455 @@
"""
Collaborator Author

Implementation of Surface-Dice functionalities

Collaborator Author

Soft skeletonization of the segmentation is achieved by iteratively eroding and dilating the membrane and keeping track of the differences.
Erosion and dilation can be implemented as min- and max-pooling, respectively. This makes the function differentiable, so we can backpropagate through it.
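For context, this is the scheme of the clDice reference implementation (the PR's version may differ in details):

```python
import torch
import torch.nn.functional as F


def soft_erode(img: torch.Tensor) -> torch.Tensor:
    # Min-pooling, written as negated max-pooling of the negated input
    return -F.max_pool3d(-img, kernel_size=3, stride=1, padding=1)


def soft_dilate(img: torch.Tensor) -> torch.Tensor:
    return F.max_pool3d(img, kernel_size=3, stride=1, padding=1)


def soft_open(img: torch.Tensor) -> torch.Tensor:
    return soft_dilate(soft_erode(img))


def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor:
    # The skeleton collects what each erosion step removes but opening
    # would not restore; every operation is differentiable
    skel = F.relu(img - soft_open(img))
    for _ in range(iter_):
        img = soft_erode(img)
        delta = F.relu(img - soft_open(img))
        skel = skel + F.relu(delta - skel * delta)  # soft union with the skeleton
    return skel
```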

skel = skel + F.relu(delta - skel * delta)
return skel


Collaborator Author

Defined a Gaussian kernel for smoothing binary segmentations before skeletonization.
I didn't find another torch function for this, so I implemented it.
This allows the smoothing to run on the GPU without moving data between devices.
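A sketch of such an on-device Gaussian smoothing (function names and defaults are assumptions; the real surface_dice.py may differ):

```python
import torch
import torch.nn.functional as F


def gaussian_kernel_3d(kernel_size: int, sigma: float, channels: int) -> torch.Tensor:
    # Separable 1D Gaussian expanded to shape (channels, 1, k, k, k),
    # the weight layout expected by a depthwise (grouped) conv3d
    coords = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
    g1d = torch.exp(-(coords**2) / (2 * sigma**2))
    g1d = g1d / g1d.sum()
    g3d = torch.einsum("i,j,k->ijk", g1d, g1d, g1d)
    return g3d.expand(channels, 1, *g3d.shape).contiguous()


def gaussian_smooth(seg: torch.Tensor, kernel_size: int = 5, sigma: float = 1.0) -> torch.Tensor:
    # seg: (B, C, Z, Y, X); runs on whatever device seg already lives on
    channels = seg.shape[1]
    kernel = gaussian_kernel_3d(kernel_size, sigma, channels).to(seg.device, seg.dtype)
    padding = kernel_size // 2
    # groups=channels filters each channel independently
    return F.conv3d(seg, kernel, padding=padding, groups=channels)
```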

Collaborator

Good idea!

filtered_seg = F.conv3d(seg, g_kernel, padding=padding, groups=seg.shape[1])
return filtered_seg


Collaborator Author

For binary segmentations, we first perform Gaussian smoothing; otherwise the soft skeletons are discontinuous.

skel_gt = soft_skel(gt_smooth, iter_=iterations)
return skel_gt


Collaborator Author

Computation of Surface-Dice, similar to centerline Dice (clDice).
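The core of that computation, following the clDice formulation in the cited paper (the PR's Surface-Dice adds ignore-label handling on top, so treat this as a sketch):

```python
import torch


def soft_cldice_loss(pred: torch.Tensor, gt: torch.Tensor,
                     skel_pred: torch.Tensor, skel_gt: torch.Tensor,
                     eps: float = 1e-6) -> torch.Tensor:
    # Topology precision: fraction of the predicted skeleton inside the GT mask
    tprec = (skel_pred * gt).sum() / (skel_pred.sum() + eps)
    # Topology sensitivity: fraction of the GT skeleton inside the prediction
    tsens = (skel_gt * pred).sum() / (skel_gt.sum() + eps)
    # One minus the harmonic mean of the two
    return 1.0 - 2.0 * tprec * tsens / (tprec + tsens + eps)
```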

@@ -0,0 +1,117 @@
def print_training_parameters(
Collaborator Author

I thought it might be useful to print the training parameters before each training run. Maybe it makes sense to make this optional?

Collaborator Author

Adjusted test function to work with new loss definitions.
I should cover more code with the tests I guess :/

Collaborator

Thanks for updating. It would be nice to increase the test coverage over time, but it's fine for this PR

Collaborator

@kevinyamauchi left a comment

This looks good to me! I have some minor comments below. Based on the conversation on Zulip with @alisterburt, it sounds like we are in agreement that this loss function should live in membrain-seg. I think you can merge after you address the minor comments. Thanks, @LorenzLamm!


Parameters
----------
data : torch.Tensor
Collaborator

It looks like the code assumes a certain shape for the data array. It would be nice to write the expected axis order in the docstring (e.g., [B, C, Z, Y, X]).

Collaborator

(same in the class below)

it performs the operation separately for each channel of each batch item.
"""
# Create the Gaussian kernel or load it from the dictionary
global gaussian_kernel_dict
Collaborator

Is it necessary to use a global here?
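One possible alternative (a hypothetical refactor, not code from this PR): memoize kernel construction with functools.lru_cache instead of a module-level dictionary:

```python
from functools import lru_cache

import torch


@lru_cache(maxsize=None)
def cached_gaussian_kernel(kernel_size: int, sigma: float, channels: int,
                           device: str) -> torch.Tensor:
    # All arguments are hashable, so lru_cache can replace the global dict
    coords = torch.arange(kernel_size, dtype=torch.float32, device=device)
    coords = coords - (kernel_size - 1) / 2
    g1d = torch.exp(-(coords**2) / (2 * sigma**2))
    g1d = g1d / g1d.sum()
    g3d = torch.einsum("i,j,k->ijk", g1d, g1d, g1d)
    return g3d.expand(channels, 1, *g3d.shape).contiguous()
```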


if self.reduction == "mean":
combined_loss = combined_loss.mean()
elif self.reduction == "sum":
combined_loss = combined_loss.sum()
Collaborator

It would be nice to add an else case that raises an error, so that it doesn't fail silently if people make a typo.
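One way the suggested guard could look (a sketch extending the snippet above, assuming "none" is also a valid setting):

```python
if self.reduction == "mean":
    combined_loss = combined_loss.mean()
elif self.reduction == "sum":
    combined_loss = combined_loss.sum()
elif self.reduction != "none":
    # Fail loudly instead of silently returning the unreduced tensor
    raise ValueError(
        f'Unknown reduction: "{self.reduction}". Use "mean", "sum", or "none".'
    )
```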

@LorenzLamm
Collaborator Author

> This looks good to me! I have some minor comments below. Based on the conversation on Zulip with @alisterburt, it sounds like we are in agreement that this loss function should live in membrain-seg. I think you can merge after you address the minor comments. Thanks, @LorenzLamm!

Cool, thanks a lot for your feedback @kevinyamauchi! Implemented your suggestions and merging now.

LorenzLamm merged commit 49aa798 into main on Jan 22, 2024
5 checks passed
LorenzLamm deleted the SurfaceDice branch on January 22, 2024 at 17:46