Surface dice loss #45
Changes from 20 commits
First changed file (the Typer CLI module defining train_advanced):
@@ -1,4 +1,7 @@
+from typing import List, Optional
+
 from typer import Option
+from typing_extensions import Annotated

 from ..train import train as _train
 from .cli import OPTION_PROMPT_KWARGS as PKWARGS
@@ -70,7 +73,7 @@ def train_advanced(
         help="Batch size for training.",
     ),
     num_workers: int = Option(  # noqa: B008
-        1,
+        8,
         help="Number of worker threads for loading data",
     ),
     max_epochs: int = Option(  # noqa: B008
@@ -84,6 +87,22 @@ def train_advanced(
         but also severely increases training time.\
         Pass "True" or "False".',
     ),
+    use_surface_dice: bool = Option(  # noqa: B008
+        False, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".'
+    ),
+    surface_dice_weight: float = Option(  # noqa: B008
+        1.0, help="Scaling factor for the Surface-Dice loss. "
+    ),
+    surface_dice_tokens: Annotated[
+        Optional[List[str]],
+        Option(
+            help='List of tokens to \
+            use for the Surface-Dice loss. \
+            Pass tokens separately:\
+            For example, train_advanced --surface_dice_tokens "ds1" \
+            --surface_dice_tokens "ds2"'
+        ),
+    ] = None,
     use_deep_supervision: bool = Option(  # noqa: B008
         True, help='Whether to use deep supervision. Pass "True" or "False".'
     ),

Review comment: surface_dice_tokens are dataset tokens specifying which datasets to apply Surface-Dice to. They need to be passed as separate arguments; I did not find a more elegant way with Typer to pass in a list of strings.

Reply: ah, that's a bummer, but I think it's okay. In the future, it might make sense to add support for a …
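For context, here is a minimal, self-contained sketch of how Typer collects a repeated option into a list, mirroring the surface_dice_tokens pattern above. All names in it are assumptions for illustration; it is not part of this PR.

# Hypothetical standalone sketch of a repeatable list option in Typer.
from typing import List, Optional

import typer
from typing_extensions import Annotated

app = typer.Typer()


@app.command()
def train(
    tokens: Annotated[
        Optional[List[str]],
        typer.Option(help="Dataset tokens; pass the option once per token."),
    ] = None,
):
    # Typer collects every occurrence of --tokens into one list;
    # if the option is never given, the parameter stays None.
    print(tokens)


if __name__ == "__main__":
    app()

Invoked as, for example, python sketch.py --tokens ds1 --tokens ds2, this would print ['ds1', 'ds2'].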
@@ -119,6 +138,12 @@ def train_advanced(
         If set to False, data augmentation still happens, but not as frequently.
         More data augmentation can lead to a better performance, but also increases the
         training time substantially.
+    use_surface_dice : bool
+        Determines whether to use the Surface-Dice loss, by default False.
+    surface_dice_weight : float
+        Scaling factor for the Surface-Dice loss, by default 1.0.
+    surface_dice_tokens : list of str, optional
+        List of dataset tokens to apply the Surface-Dice loss to, by default None.
     use_deep_supervision : bool
         Determines whether to use deep supervision, by default True.
     project_name : str
@@ -140,6 +165,9 @@ def train_advanced(
         max_epochs=max_epochs,
         aug_prob_to_one=aug_prob_to_one,
         use_deep_supervision=use_deep_supervision,
+        use_surf_dice=use_surface_dice,
+        surf_dice_weight=surface_dice_weight,
+        surf_dice_tokens=surface_dice_tokens,
         project_name=project_name,
         sub_name=sub_name,
     )
Second changed file (the dataset-loading module):
@@ -1,7 +1,6 @@
 import os
 from typing import Dict

-# from skimage import io
 import imageio as io
 import numpy as np
 from torch.utils.data import Dataset
@@ -102,6 +101,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
             "label": np.expand_dims(self.labels[idx], 0),
         }
         idx_dict = self.transforms(idx_dict)
+        idx_dict["dataset"] = self.dataset_labels[idx]
         return idx_dict

     def __len__(self) -> int:

Review comment: the dataset token is now returned with every train image.
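As a rough illustration of what a consumer of the dataset now receives from __getitem__ (the "label" and "dataset" keys appear in the diff; the "image" key and everything else here is an assumption):

# Hypothetical usage; train_dataset stands for an instance of the Dataset class above.
sample = train_dataset[0]
image = sample["image"]     # image patch, already passed through self.transforms (assumed key)
label = sample["label"]     # corresponding label volume
token = sample["dataset"]   # dataset token string, e.g. "ds1", newly added by this PR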
@@ -126,6 +126,7 @@ def load_data(self) -> None:
         print("Loading images into dataset.")
         self.imgs = []
         self.labels = []
+        self.dataset_labels = []
         for entry in self.data_paths:
             label = read_nifti(
                 entry[1]
@@ -137,6 +138,7 @@ def load_data(self) -> None:
             img = np.transpose(img, (1, 2, 0))
             self.imgs.append(img)
             self.labels.append(label)
+            self.dataset_labels.append(get_dataset_token(entry[0]))

     def initialize_imgs_paths(self) -> None:
         """
@@ -190,3 +192,23 @@ def test(self, test_folder: str, num_files: int = 20) -> None:
                 os.path.join(test_folder, f"test_mask_ds2_{i}_group{num_mask}.png"),
                 test_sample["label"][1][0, :, :, num_mask],
             )
+
+
+def get_dataset_token(patch_name):
+    """
+    Get the dataset token from the patch name.
+
+    Parameters
+    ----------
+    patch_name : str
+        The name of the patch.
+
+    Returns
+    -------
+    str
+        The dataset token.
+
+    """
+    basename = os.path.basename(patch_name)
+    dataset_token = basename.split("_")[0]
+    return dataset_token

Review comment: the dataset token is defined as the first token before the 1st underscore.
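For instance, with hypothetical patch paths (not taken from this PR), the helper would behave as follows:

# Only the basename's text before the first underscore is kept as the token.
get_dataset_token("/data/patches/ds2_tomo17_patch003.nii.gz")  # -> "ds2"
get_dataset_token("ds1_0001.nii.gz")                           # -> "ds1"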
Review comment: the dataset token is now readable as well. This helps to distinguish between different datasets, because we may want to apply different loss functions (particularly Surface-Dice) to some datasets, but not to others.
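The diff itself does not show how the tokens gate the loss; a minimal sketch of the idea, with all names, tensor assumptions, and the combination rule assumed rather than taken from the actual implementation, could look like:

# Hypothetical sketch: add Surface-Dice only for batch samples whose dataset
# token is listed in surf_dice_tokens, scaled by surf_dice_weight.
def combined_loss(pred, target, batch_tokens, base_loss_fn, surf_dice_fn,
                  surf_dice_tokens=None, surf_dice_weight=1.0):
    # pred/target: arrays or tensors indexable along the batch dimension (assumption)
    loss = base_loss_fn(pred, target)
    if surf_dice_tokens is not None:
        # indices of the samples in this batch whose token is selected for Surface-Dice
        selected = [i for i, tok in enumerate(batch_tokens) if tok in surf_dice_tokens]
        if selected:
            loss = loss + surf_dice_weight * surf_dice_fn(pred[selected], target[selected])
    return loss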