Surface dice loss #45
Conversation
@@ -21,6 +21,11 @@ def extract_patches(
    help="Path to the folder where extracted patches should be stored. \
    (subdirectories will be created)",
),
Dataset token is now readable as well. This helps to distinguish between different datasets, because we may want to apply different loss functions (particularly Surface-Dice) to some datasets, but not to others.
@@ -84,6 +87,22 @@ def train_advanced(
    but also severely increases training time.\
    Pass "True" or "False".',
),
surface_dice_tokens: dataset tokens specifying which datasets to apply surface-dice to.
Needs to be passed as separate arguments:
--surface-dice-tokens ds1 --surface-dice-tokens ds2
I did not find a more elegant way with Typer to pass in a list of strings
Ah, that's a bummer, but I think it's okay. In the future, it might make sense to add support for a glob string or a directory so that people don't have to write out all of the tokens.
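For readers unfamiliar with this Typer pattern, here is a minimal, hypothetical sketch (not the actual membrain-seg CLI) of an `Optional[List[str]]` option that is filled by repeating the flag:

```python
from typing import List, Optional

import typer

app = typer.Typer()


@app.command()
def train_advanced(
    surface_dice_tokens: Optional[List[str]] = typer.Option(
        None,
        help="Dataset tokens to apply Surface-Dice to. Repeat the flag once per token.",
    ),
):
    # `train_advanced --surface-dice-tokens ds1 --surface-dice-tokens ds2`
    # yields surface_dice_tokens == ["ds1", "ds2"]
    typer.echo(f"Surface-Dice tokens: {surface_dice_tokens}")


if __name__ == "__main__":
    app()
```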
@@ -102,6 +101,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
    "label": np.expand_dims(self.labels[idx], 0),
}
idx_dict = self.transforms(idx_dict)
dataset token is now returned with every train image
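A minimal, self-contained sketch of the idea (class and attribute names here are illustrative, not the actual membrain-seg dataset class): each sample dictionary carries its dataset token alongside the image and label, so the loss can later be applied selectively.

```python
from typing import Dict

import numpy as np


class ToyPatchDataset:
    """Toy dataset that attaches a dataset token to every sample."""

    def __init__(self, imgs, labels, dataset_tokens):
        self.imgs = imgs
        self.labels = labels
        self.dataset_tokens = dataset_tokens  # e.g. ["ds1", "ds2", ...]

    def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
        return {
            "image": np.expand_dims(self.imgs[idx], 0),
            "label": np.expand_dims(self.labels[idx], 0),
            "dataset": self.dataset_tokens[idx],  # used later to select losses
        }
```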
@@ -190,3 +192,23 @@ def test(self, test_folder: str, num_files: int = 20) -> None:
    os.path.join(test_folder, f"test_mask_ds2_{i}_group{num_mask}.png"),
    test_sample["label"][1][0, :, :, num_mask],
)
dataset token is defined as first token before 1st underscore
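As an illustration (the helper name and file naming are hypothetical), extracting the token from a patch file name could look like this:

```python
import os


def get_dataset_token(file_path: str) -> str:
    """Return the part of the file name before the first underscore."""
    basename = os.path.basename(file_path)
    return basename.split("_")[0]


assert get_dataset_token("/patches/imgs/ds2_patch_0001.png") == "ds2"
```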
deep_supervision=True,
deep_supr_num=2,
)
ignore_dice_loss = IgnoreLabelDiceCELoss(ignore_label=2, reduction="mean")
The following initializes a weighted average between loss functions. BCE/DICE loss is used by default; Surface-Dice is optionally added.
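A self-contained sketch of that setup (variable names mirror the diff; the concrete loss objects are stand-ins, and the "all" marker is an assumption):

```python
use_surface_dice = True
surf_dice_weight = 1.0
surf_dice_tokens = ["ds1", "ds2"]

losses = ["BCE/Dice"]              # stand-in for the default loss, always applied
weights = [1.0]
loss_inclusion_tokens = [["all"]]  # "all" = apply to every dataset

if use_surface_dice:               # Surface-Dice is only added on request
    losses.append("SurfaceDice")
    weights.append(surf_dice_weight)
    loss_inclusion_tokens.append(surf_dice_tokens)

# Normalize so the loss terms form a weighted average
scaled_weights = [w / sum(weights) for w in weights]
print(scaled_weights)  # [0.5, 0.5]
```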
Codecov Report

@@           Coverage Diff            @@
##            main      #45      +/-  ##
========================================
+ Coverage   5.41%    7.73%   +2.31%
========================================
  Files         38       40       +2
  Lines       1256     1410     +154
========================================
+ Hits          68      109      +41
- Misses      1188     1301     +113
loss_inclusion_tokens.append(surf_dice_tokens)

scaled_weights = [entry / sum(weights) for entry in weights]
The combined loss function computes the losses only for selected datasets. BCE/DICE is computed for all datasets, while the datasets for Surface-Dice can be chosen explicitly.
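The per-dataset selection could be implemented roughly like this (a hedged sketch; the token names and loss values are made up):

```python
import torch

ds_labels = ["ds1", "ds3", "ds1", "ds2"]   # one dataset token per batch element
surf_dice_tokens = ["ds1", "ds2"]          # datasets that get Surface-Dice

mask = torch.tensor([tok in surf_dice_tokens for tok in ds_labels])
per_sample_loss = torch.tensor([0.4, 0.7, 0.3, 0.5])  # non-reduced Surface-Dice losses

# BCE/Dice would use all elements; Surface-Dice only the masked ones
surface_dice_term = (per_sample_loss * mask).sum() / mask.sum().clamp(min=1)
print(surface_dice_term)  # mean over the ds1/ds2 samples only
```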
output = self.forward(images)
loss = self.loss_function(output, labels)
loss = self.loss_function(output, labels, ds_label)

stats_dict = {"train_loss": loss, "train_number": output[0].shape[0]}
self.training_step_outputs.append(stats_dict)
self.running_train_acc += (
    masked_accuracy(output[0], labels[0], ignore_label=2.0, threshold_value=0.0)
    * output[0].shape[0]
)
also log surface-dice during training
surf_dice_weight : float, optional
    Weight for the Surface-Dice loss.
surf_dice_tokens : list, optional
    List of tokens to use for the Surface-Dice loss.

Returns
-------
None
"""
This prints a summary of training parameters before each training run.
@@ -106,7 +133,7 @@ def on_epoch_start(self, trainer, pl_module):
# Set up the trainer
trainer = pl.Trainer(
    precision="16-mixed",
wandb logging requires an additional dependency & wandb registration. It should become an option in the future, but is removed for now.
@@ -34,7 +34,7 @@ def masked_accuracy(
mask = (
    y_gt == ignore_label
    if ignore_label is not None
    else torch.ones_like(y_gt).bool()
    else torch.zeros_like(y_gt).bool()
Fixed issue mentioned in #44
@@ -53,7 +53,7 @@ class IgnoreLabelDiceCELoss(_Loss):
def __init__(
    self,
    ignore_label: int,
Loss functions should by default return non-reduced losses, i.e., one value for each element in the batch.
This way, we can decide whether to apply the respective loss for each element in the batch (depending on the dataset token).
Agreed!
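To illustrate the "non-reduced" convention (a sketch with an arbitrary stand-in loss, not the membrain-seg implementation): with `reduction="none"`, a loss returns one value per voxel, which can then be averaged to one value per batch element.

```python
import torch
import torch.nn as nn

pred = torch.randn(4, 1, 8, 8, 8)   # [B, C, Z, Y, X]
target = torch.rand(4, 1, 8, 8, 8)

bce = nn.BCEWithLogitsLoss(reduction="none")
per_voxel = bce(pred, target)                    # same shape as pred
per_sample = per_voxel.mean(dim=(1, 2, 3, 4))    # one value per batch element
print(per_sample.shape)                          # torch.Size([4])
```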
return loss


class CombinedLoss(_Loss):
The combined loss function computes the weighted average of all losses and considers only the specified datasets.
In this way, we can choose exactly which losses to apply to which dataset.
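A simplified, hedged sketch of such a combined loss (not the exact membrain-seg class; the "all" marker and the signature are assumptions): each sub-loss returns per-sample values, and samples whose dataset token is not in that loss's inclusion list are zeroed out before the weighted sum.

```python
from typing import Sequence

import torch
from torch.nn.modules.loss import _Loss


class CombinedLossSketch(_Loss):
    """Weighted sum of per-sample losses, each applied only to selected datasets."""

    def __init__(self, losses, weights, loss_inclusion_tokens, reduction="mean"):
        super().__init__()
        self.losses = losses                              # each returns one value per sample
        self.weights = weights
        self.loss_inclusion_tokens = loss_inclusion_tokens
        self.reduction = reduction

    def forward(self, pred, target, ds_labels: Sequence[str]) -> torch.Tensor:
        combined = torch.zeros(pred.shape[0], device=pred.device)
        for loss_fn, weight, tokens in zip(
            self.losses, self.weights, self.loss_inclusion_tokens
        ):
            per_sample = loss_fn(pred, target)            # non-reduced loss values
            keep = torch.tensor(
                [("all" in tokens) or (tok in tokens) for tok in ds_labels],
                dtype=per_sample.dtype,
                device=pred.device,
            )
            combined = combined + weight * per_sample * keep
        return combined.mean() if self.reduction == "mean" else combined.sum()


# Example with a stand-in per-sample loss (mean absolute error):
mae = lambda p, t: (p - t).abs().mean(dim=(1, 2, 3, 4))
loss_fn = CombinedLossSketch([mae], [1.0], [["all"]])
```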
@@ -0,0 +1,455 @@
"""
Implementation of Surface-Dice functionalities
Soft skeletonization of the segmentation is achieved by iteratively eroding and dilating the membrane and keeping track of the differences.
Erosion and dilation can be implemented as min- and max-pooling, respectively. This makes the function differentiable, so we can backpropagate through it.
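A hedged sketch of the differentiable morphology (following the clDice paper's soft-skeleton idea; kernel size and iteration count are illustrative, not the exact membrain-seg code):

```python
import torch
import torch.nn.functional as F


def soft_erode(img: torch.Tensor) -> torch.Tensor:
    # Min-pooling implemented as negated max-pooling of the negated image
    return -F.max_pool3d(-img, kernel_size=3, stride=1, padding=1)


def soft_dilate(img: torch.Tensor) -> torch.Tensor:
    return F.max_pool3d(img, kernel_size=3, stride=1, padding=1)


def soft_skeletonize(img: torch.Tensor, iterations: int = 5) -> torch.Tensor:
    # Iteratively erode; whatever a soft "opening" removes is added to the skeleton
    skel = F.relu(img - soft_dilate(soft_erode(img)))
    for _ in range(iterations):
        img = soft_erode(img)
        delta = F.relu(img - soft_dilate(soft_erode(img)))
        skel = skel + F.relu(delta - skel * delta)
    return skel


seg = torch.rand(1, 1, 16, 16, 16)  # [B, C, Z, Y, X] soft probabilities
print(soft_skeletonize(seg).shape)
```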
skel = skel + F.relu(delta - skel * delta)
return skel
Defined Gaussian kernel for smoothing binary segmentations before skeletonization.
I didn't find another torch function for this, so I implemented it.
This allows computation of smoothing on GPU without moving stuff between devices.
Good idea!
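For illustration, a self-contained sketch of building such a kernel directly in torch (sigma and kernel size are made-up values; the real implementation may cache kernels, as discussed further below):

```python
import torch
import torch.nn.functional as F


def gaussian_kernel_3d(kernel_size: int = 5, sigma: float = 1.0) -> torch.Tensor:
    coords = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
    g1d = torch.exp(-(coords**2) / (2 * sigma**2))
    g1d = g1d / g1d.sum()
    # Outer product over three axes gives a separable 3D Gaussian
    g3d = torch.einsum("i,j,k->ijk", g1d, g1d, g1d)
    return g3d.view(1, 1, kernel_size, kernel_size, kernel_size)


seg = torch.rand(2, 1, 16, 16, 16)                      # [B, C, Z, Y, X]
kernel = gaussian_kernel_3d().to(seg.device)
smoothed = F.conv3d(seg, kernel, padding=2, groups=seg.shape[1])
print(smoothed.shape)
```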
filtered_seg = F.conv3d(seg, g_kernel, padding=padding, groups=seg.shape[1])
return filtered_seg
For binary segmentations, we first perform Gaussian smoothing; otherwise, the soft skeletons become discontinuous.
skel_gt = soft_skel(gt_smooth, iter_=iterations)
return skel_gt
Computation of the Surface-Dice is similar to centerline Dice (clDice).
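For reference, a hedged sketch of the clDice-style score (following the cited paper; not the exact membrain-seg code): topology precision and sensitivity are computed from the soft skeletons and combined via their harmonic mean. A loss term would then typically be 1 minus this score.

```python
import torch


def soft_cldice_score(
    pred: torch.Tensor,
    gt: torch.Tensor,
    skel_pred: torch.Tensor,
    skel_gt: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    # Topology precision: fraction of the predicted skeleton inside the GT mask
    tprec = (skel_pred * gt).sum(dim=(1, 2, 3, 4)) / (skel_pred.sum(dim=(1, 2, 3, 4)) + eps)
    # Topology sensitivity: fraction of the GT skeleton inside the prediction
    tsens = (skel_gt * pred).sum(dim=(1, 2, 3, 4)) / (skel_gt.sum(dim=(1, 2, 3, 4)) + eps)
    # Harmonic mean, one value per batch element (non-reduced)
    return 2.0 * tprec * tsens / (tprec + tsens + eps)
```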
@@ -0,0 +1,117 @@
def print_training_parameters(
I thought it might be useful to print training parameters before each training run. Maybe makes sense to make this optional?
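A toy sketch of the idea (the real function prints many more, specific fields):

```python
def print_training_parameters(**params) -> None:
    """Print a short summary of the training configuration."""
    print("Training parameters:")
    for key, value in params.items():
        print(f"  {key}: {value}")


print_training_parameters(data_dir="./patches", batch_size=2, use_surf_dice=True)
```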
Adjusted test function to work with new loss definitions.
I should cover more code with the tests I guess :/
Thanks for updating. It would be nice to increase the test coverage over time, but it's fine for this PR
This looks good to me! I have some minor comments below. Based on the conversation on Zulip with @alisterburt, it sounds like we are in agreement that this loss function should live in membrain-seg. I think you can merge after you address the minor comments. Thanks, @LorenzLamm!
Parameters
----------
data : torch.Tensor
It looks like the code assumes a certain shape for the data array. It would be nice to write the expected axis order in the docstring (e.g., [B, C, Z, Y, X]).
(same in the class below)
it performs the operation separately for each channel of each batch item.
"""
# Create the Gaussian kernel or load it from the dictionary
global gaussian_kernel_dict
Is it necessary to use a global here?
if self.reduction == "mean":
    combined_loss = combined_loss.mean()
elif self.reduction == "sum":
    combined_loss = combined_loss.sum()
It would be nice to add an else case that raises an error so that it doesn't fail silently if people make a typo.
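For illustration, the suggested guard could look like this (a sketch; the error message is made up):

```python
import torch


def reduce_loss(combined_loss: torch.Tensor, reduction: str) -> torch.Tensor:
    if reduction == "mean":
        return combined_loss.mean()
    elif reduction == "sum":
        return combined_loss.sum()
    else:
        raise ValueError(f"Invalid reduction: {reduction!r}. Use 'mean' or 'sum'.")
```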
Cool, thanks a lot for your feedback @kevinyamauchi! Implemented your suggestions and merging now.
Added the option to use Surface-Dice as a loss function during training.
Surface-Dice is based on "clDice - a Novel Topology-Preserving Loss Function for Tubular Structure Segmentation"
(https://openaccess.thecvf.com/content/CVPR2021/papers/Shit_clDice_-_A_Novel_Topology-Preserving_Loss_Function_for_Tubular_Structure_CVPR_2021_paper.pdf)
Also fixed some issues for patch extraction (corrected naming), removed wandb tracking (caused dependency issues), fixed bug mentioned in #44, and added printing of training parameter summary.