diff --git a/documentation/competitions/AutoPETII.md b/documentation/competitions/AutoPETII.md index 075256a..f15ec5b 100644 --- a/documentation/competitions/AutoPETII.md +++ b/documentation/competitions/AutoPETII.md @@ -46,7 +46,7 @@ Add the following to the 'configurations' dict in 'nnUNetPlans.json': ```json "3d_fullres_resenc": { "inherits_from": "3d_fullres", - "UNet_class_name": "ResidualEncoderUNet", + "network_arch_class_name": "ResidualEncoderUNet", "n_conv_per_stage_encoder": [ 1, 3, diff --git a/documentation/dataset_format.md b/documentation/dataset_format.md index de6c993..cd8433a 100644 --- a/documentation/dataset_format.md +++ b/documentation/dataset_format.md @@ -26,7 +26,8 @@ T2 MRI, …) and FILE_ENDING is the file extension used by your image format (.p The dataset.json file connects channel names with the channel identifiers in the 'channel_names' key (see below for details). Side note: Typically, each channel/modality needs to be stored in a separate file and is accessed with the XXXX channel identifier. -Exception are natural images (RGB; .png) where the three color channels can all be stored in one file (see the [road segmentation](../nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py) dataset as an example). +Exception are natural images (RGB; .png) where the three color channels can all be stored in one file (see the +[road segmentation](../nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py) dataset as an example). **Segmentations** must share the same geometry with their corresponding images (same shape etc.). Segmentations are integer maps with each value representing a semantic class. The background must be 0. If there is no background, then @@ -57,14 +58,14 @@ of what the raw data was provided in! This is for performance reasons. By default, the following file formats are supported: + - NaturalImage2DIO: .png, .bmp, .tif - NibabelIO: .nii.gz, .nrrd, .mha - NibabelIOWithReorient: .nii.gz, .nrrd, .mha. This reader will reorient images to RAS! - SimpleITKIO: .nii.gz, .nrrd, .mha - Tiff3DIO: .tif, .tiff. 3D tif images! Since TIF does not have a standardized way of storing spacing information, -nnU-Net expects each TIF file to be accompanied by an identically named .json file that contains three numbers -(no units, no comma. Just separated by whitespace), one for each dimension. - +nnU-Net expects each TIF file to be accompanied by an identically named .json file that contains this information (see +[here](#datasetjson)). The file extension lists are not exhaustive and depend on what the backend supports. For example, nibabel and SimpleITK support more than the three given here. The file endings given here are just the ones we tested! @@ -200,6 +201,27 @@ There is a utility with which you can generate the dataset.json automatically. Y [here](../nnunetv2/dataset_conversion/generate_dataset_json.py). See our examples in [dataset_conversion](../nnunetv2/dataset_conversion) for how to use it. And read its documentation! +As described above, a json file that contains spacing information is required for TIFF files. +An example for a 3D TIFF stack with units corresponding to 7.6 in x and y, 80 in z is: + +``` +{ + "spacing": [7.6, 7.6, 80.0] +} +``` + +Within the dataset folder, this file (named `cell6.json` in this example) would be placed in the following folders: + + nnUNet_raw/Dataset123_Foo/ + ├── dataset.json + ├── imagesTr + │   ├── cell6.json + │   └── cell6_0000.tif + └── labelsTr + ├── cell6.json + └── cell6.tif + + ## How to use nnU-Net v1 Tasks If you are migrating from the old nnU-Net, convert your existing datasets with `nnUNetv2_convert_old_nnUNet_dataset`! diff --git a/documentation/explanation_plans_files.md b/documentation/explanation_plans_files.md index 00f1216..13ccda8 100644 --- a/documentation/explanation_plans_files.md +++ b/documentation/explanation_plans_files.md @@ -74,7 +74,7 @@ nnunetv2.preprocessing.resampling resampling function must be callable(data, current_spacing, new_spacing, **kwargs). It must be located in nnunetv2.preprocessing.resampling - `resampling_fn_seg_kwargs`: kwargs for resampling_fn_seg -- `UNet_class_name`: UNet class name, can be used to integrate custom dynamic architectures +- `network_arch_class_name`: UNet class name, can be used to integrate custom dynamic architectures - `UNet_base_num_features`: The number of starting features for the UNet architecture. Default is 32. Default: Features are doubled with each downsampling - `unet_max_num_features`: Maximum number of features (default: capped at 320 for 3D and 512 for 2d). The purpose is to diff --git a/nnunetv2/batch_running/collect_results_custom_Decathlon.py b/nnunetv2/batch_running/collect_results_custom_Decathlon.py index b670661..77e7dfb 100644 --- a/nnunetv2/batch_running/collect_results_custom_Decathlon.py +++ b/nnunetv2/batch_running/collect_results_custom_Decathlon.py @@ -94,21 +94,19 @@ def summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[st if __name__ == '__main__': use_these_trainers = { - 'nnUNetTrainer': ('nnUNetPlans',), - 'nnUNetTrainerDiceCELoss_noSmooth': ('nnUNetPlans',), - 'nnUNetTrainer_DASegOrd0': ('nnUNetPlans',), + 'nnUNetTrainer': ('nnUNetPlans', 'nnUNetResEncUNetPlans', 'nnUNetResEncUNet2Plans', 'nnUNetResBottleneckEncUNetPlans', 'nnUNetResUNetPlans', 'nnUNetResUNet2Plans', 'nnUNetResUNet3Plans', 'nnUNetDeeperResBottleneckEncUNetPlans'), } all_results_file= join(nnUNet_results, 'customDecResults.csv') - datasets = [2, 3, 4, 17, 20, 24, 27, 38, 55, 64, 82] + datasets = [2, 3, 4, 17, 24, 27, 38, 55, 137, 217, 220, 221, 223] # amos post challenge, kits2023 collect_results(use_these_trainers, datasets, all_results_file) folds = (0, 1, 2, 3, 4) - configs = ("3d_fullres", "3d_lowres") + configs = ("3d_fullres", ) output_file = join(nnUNet_results, 'customDecResults_summary5fold.csv') summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) folds = (0, ) - configs = ("3d_fullres", "3d_lowres") + configs = ("3d_fullres", ) output_file = join(nnUNet_results, 'customDecResults_summaryfold0.csv') summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) diff --git a/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py b/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py index 0a75fbd..7f9726e 100644 --- a/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py +++ b/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py @@ -21,18 +21,18 @@ def merge(dict1, dict2): # after the Nature Methods paper we switch our evaluation to a different (more stable/high quality) set of # datasets for evaluation and future development configurations_all = { - 2: ("3d_fullres", "2d"), - 3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), - 4: ("2d", "3d_fullres"), + # 2: ("3d_fullres", "2d"), + # 3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 4: ("2d", "3d_fullres"), 17: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), - 20: ("2d", "3d_fullres"), - 24: ("2d", "3d_fullres"), - 27: ("2d", "3d_fullres"), - 38: ("2d", "3d_fullres"), - 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), - 64: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), - 82: ("2d", "3d_fullres"), - # 83: ("2d", "3d_fullres"), + # 24: ("2d", "3d_fullres"), + # 27: ("2d", "3d_fullres"), + # 38: ("2d", "3d_fullres"), + # 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + 137: ("2d", "3d_fullres"), + 220: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 221: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + 223: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), } configurations_3d_fr_only = { @@ -52,25 +52,23 @@ def merge(dict1, dict2): } num_gpus = 1 - exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\" -R \"select[hname!='e230-dgx1-1']\" -R \"select[hname!='e230-dgxa100-1']\" -R \"select[hname!='e230-dgxa100-2']\" -R \"select[hname!='e230-dgxa100-3']\" -R \"select[hname!='e230-dgxa100-4']\"" - resources = "-R \"tensorcore\"" + exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\"" + resources = "" gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmem=33G" - queue = "-q gpu-lowprio" - preamble = "-L /bin/bash \"source ~/load_env_cluster4.sh && " - train_command = 'nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/nnUNet_results_remake_release nnUNetv2_train' + queue = "-q gpu" + preamble = "-L /bin/bash \"source ~/load_env_mamba_slumber.sh && " + train_command = 'nnUNetv2_train' - folds = (0, ) + folds = (1, 2, 3, 4) # use_this = configurations_2d_only - use_this = merge(configurations_3d_fr_only, configurations_3d_lr_only) + use_this = configurations_3d_fr_only # use_this = merge(use_this, configurations_3d_c_only) use_these_modules = { - 'nnUNetTrainer': ('nnUNetPlans',), - 'nnUNetTrainerDiceCELoss_noSmooth': ('nnUNetPlans',), - # 'nnUNetTrainer_DASegOrd0': ('nnUNetPlans',), + 'nnUNetTrainer': ('nnUNetPlans', 'nnUNetResEncUNetMPlans', 'nnUNetResEncUNetLPlans', 'nnUNetResEncUNetXLPlans'), } - additional_arguments = f'--disable_checkpointing -num_gpus {num_gpus}' # '' + additional_arguments = f' -num_gpus {num_gpus}' # '' output_file = "/home/isensee/deleteme.txt" with open(output_file, 'w') as f: diff --git a/nnunetv2/dataset_conversion/Dataset027_ACDC.py b/nnunetv2/dataset_conversion/Dataset027_ACDC.py index 569ff6f..8ebc251 100644 --- a/nnunetv2/dataset_conversion/Dataset027_ACDC.py +++ b/nnunetv2/dataset_conversion/Dataset027_ACDC.py @@ -1,9 +1,12 @@ import os import shutil from pathlib import Path +from typing import List +from batchgenerators.utilities.file_and_folder_operations import nifti_files, join, maybe_mkdir_p, save_json from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json -from nnunetv2.paths import nnUNet_raw +from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed +import numpy as np def make_out_dirs(dataset_id: int, task_name="ACDC"): @@ -22,6 +25,22 @@ def make_out_dirs(dataset_id: int, task_name="ACDC"): return out_dir, out_train_dir, out_labels_dir, out_test_dir +def create_ACDC_split(labelsTr_folder: str, seed: int = 1234) -> List[dict[str, List]]: + # labelsTr_folder = '/home/isensee/drives/gpu_data_root/OE0441/isensee/nnUNet_raw/nnUNet_raw_remake/Dataset027_ACDC/labelsTr' + nii_files = nifti_files(labelsTr_folder, join=False) + patients = np.unique([i[:len('patient000')] for i in nii_files]) + rs = np.random.RandomState(seed) + rs.shuffle(patients) + splits = [] + for fold in range(5): + val_patients = patients[fold::5] + train_patients = [i for i in patients if i not in val_patients] + val_cases = [i[:-7] for i in nii_files for j in val_patients if i.startswith(j)] + train_cases = [i[:-7] for i in nii_files for j in train_patients if i.startswith(j)] + splits.append({'train': train_cases, 'val': val_cases}) + return splits + + def copy_files(src_data_folder: Path, train_dir: Path, labels_dir: Path, test_dir: Path): """Copy files from the ACDC dataset to the nnUNet dataset folder. Returns the number of training cases.""" patients_train = sorted([f for f in (src_data_folder / "training").iterdir() if f.is_dir()]) @@ -84,4 +103,12 @@ def convert_acdc(src_data_folder: str, dataset_id=27): args = parser.parse_args() print("Converting...") convert_acdc(args.input_folder, args.dataset_id) + + dataset_name = f"Dataset{args.dataset_id:03d}_{'ACDC'}" + labelsTr = join(nnUNet_raw, dataset_name, 'labelsTr') + preprocessed_folder = join(nnUNet_preprocessed, dataset_name) + maybe_mkdir_p(preprocessed_folder) + split = create_ACDC_split(labelsTr) + save_json(split, join(preprocessed_folder, 'splits_final.json'), sort_keys=False) + print("Done!") diff --git a/nnunetv2/dataset_conversion/Dataset220_KiTS2023.py b/nnunetv2/dataset_conversion/Dataset220_KiTS2023.py index 20a794c..7f0d0e9 100644 --- a/nnunetv2/dataset_conversion/Dataset220_KiTS2023.py +++ b/nnunetv2/dataset_conversion/Dataset220_KiTS2023.py @@ -31,7 +31,7 @@ def convert_kits2023(kits_base_dir: str, nnunet_dataset_id: int = 220): regions_class_order=(1, 3, 2), num_training_cases=len(cases), file_ending='.nii.gz', dataset_name=task_name, reference='none', - release='prerelease', + release='0.1.3', overwrite_image_reader_writer='NibabelIOWithReorient', description="KiTS2023") diff --git a/nnunetv2/dataset_conversion/Dataset223_AMOS2022postChallenge.py b/nnunetv2/dataset_conversion/Dataset223_AMOS2022postChallenge.py new file mode 100644 index 0000000..cded73d --- /dev/null +++ b/nnunetv2/dataset_conversion/Dataset223_AMOS2022postChallenge.py @@ -0,0 +1,59 @@ +import shutil + +from batchgenerators.utilities.file_and_folder_operations import * +from nnunetv2.paths import nnUNet_raw +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json + +if __name__ == '__main__': + downloaded_amos_dir = '/home/isensee/amos22/amos22' # downloaded and extracted from https://zenodo.org/record/7155725#.Y0OOCOxBztM + + target_dataset_id = 223 + target_dataset_name = f'Dataset{target_dataset_id:3.0f}_AMOS2022postChallenge' + + maybe_mkdir_p(join(nnUNet_raw, target_dataset_name)) + imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr') + imagesTs = join(nnUNet_raw, target_dataset_name, 'imagesTs') + labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr') + maybe_mkdir_p(imagesTr) + maybe_mkdir_p(imagesTs) + maybe_mkdir_p(labelsTr) + + train_identifiers = [] + # copy images + source = join(downloaded_amos_dir, 'imagesTr') + source_files = nifti_files(source, join=False) + train_identifiers += source_files + for s in source_files: + shutil.copy(join(source, s), join(imagesTr, s[:-7] + '_0000.nii.gz')) + + source = join(downloaded_amos_dir, 'imagesVa') + source_files = nifti_files(source, join=False) + train_identifiers += source_files + for s in source_files: + shutil.copy(join(source, s), join(imagesTr, s[:-7] + '_0000.nii.gz')) + + source = join(downloaded_amos_dir, 'imagesTs') + source_files = nifti_files(source, join=False) + for s in source_files: + shutil.copy(join(source, s), join(imagesTs, s[:-7] + '_0000.nii.gz')) + + # copy labels + source = join(downloaded_amos_dir, 'labelsTr') + source_files = nifti_files(source, join=False) + for s in source_files: + shutil.copy(join(source, s), join(labelsTr, s)) + + source = join(downloaded_amos_dir, 'labelsVa') + source_files = nifti_files(source, join=False) + for s in source_files: + shutil.copy(join(source, s), join(labelsTr, s)) + + old_dataset_json = load_json(join(downloaded_amos_dir, 'dataset.json')) + new_labels = {v: k for k, v in old_dataset_json['labels'].items()} + + generate_dataset_json(join(nnUNet_raw, target_dataset_name), {0: 'nonCT'}, new_labels, + num_training_cases=len(train_identifiers), file_ending='.nii.gz', regions_class_order=None, + dataset_name=target_dataset_name, reference='https://zenodo.org/record/7155725#.Y0OOCOxBztM', + license=old_dataset_json['licence'], # typo in OG dataset.json + description=old_dataset_json['description'], + release=old_dataset_json['release']) diff --git a/nnunetv2/evaluation/evaluate_predictions.py b/nnunetv2/evaluation/evaluate_predictions.py index 80e4d24..18f0df9 100644 --- a/nnunetv2/evaluation/evaluate_predictions.py +++ b/nnunetv2/evaluation/evaluate_predictions.py @@ -33,7 +33,7 @@ def key_to_label_or_region(key: str): def save_summary_json(results: dict, output_file: str): """ - stupid json does not support tuples as keys (why does it have to be so shitty) so we need to convert that shit + json does not support tuples as keys (why does it have to be so shitty) so we need to convert that shit ourselves """ results_converted = deepcopy(results) diff --git a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py index 2b1c412..c0ac2d3 100644 --- a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py @@ -1,12 +1,11 @@ -import os.path import shutil from copy import deepcopy -from functools import lru_cache -from typing import List, Union, Tuple, Type +from typing import List, Union, Tuple import numpy as np +import torch from batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p -from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet +from dynamic_network_architectures.architectures.unet import PlainConvUNet from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm from nnunetv2.configuration import ANISO_THRESHOLD @@ -16,9 +15,10 @@ from nnunetv2.preprocessing.normalization.map_channel_name_to_normalization import get_normalization_scheme from nnunetv2.preprocessing.resampling.default_resampling import resample_data_or_seg_to_shape, compute_new_shape from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA +from nnunetv2.utilities.get_network_from_plans import get_network_from_plans from nnunetv2.utilities.json_export import recursive_fix_for_json_export -from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \ - get_filenames_of_train_images_and_targets +from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets class ExperimentPlanner(object): @@ -57,13 +57,16 @@ def __init__(self, dataset_name_or_id: Union[str, int], self.UNet_reference_val_corresp_GB = 8 self.UNet_reference_val_corresp_bs_2d = 12 self.UNet_reference_val_corresp_bs_3d = 2 - self.UNet_vram_target_GB = gpu_memory_target_in_gb self.UNet_featuremap_min_edge_length = 4 self.UNet_blocks_per_stage_encoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) self.UNet_blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) self.UNet_min_batch_size = 2 self.UNet_max_features_2d = 512 self.UNet_max_features_3d = 320 + self.max_dataset_covered = 0.05 # we limit the batch size so that no more than 5% of the dataset can be seen + # in a single forward/backward pass + + self.UNet_vram_target_GB = gpu_memory_target_in_gb self.lowres_creation_threshold = 0.25 # if the patch size of fullres is less than 25% of the voxels in the # median shape then we need a lowres config as well @@ -79,37 +82,33 @@ def __init__(self, dataset_name_or_id: Union[str, int], self.plans = None + if isfile(join(self.raw_dataset_folder, 'splits_final.json')): + _maybe_copy_splits_file(join(self.raw_dataset_folder, 'splits_final.json'), + join(preprocessed_folder, 'splits_final.json')) + def determine_reader_writer(self): example_image = self.dataset[self.dataset.keys().__iter__().__next__()]['images'][0] return determine_reader_writer_from_dataset_json(self.dataset_json, example_image) @staticmethod - @lru_cache(maxsize=None) def static_estimate_VRAM_usage(patch_size: Tuple[int], - n_stages: int, - strides: Union[int, List[int], Tuple[int, ...]], - UNet_class: Union[Type[PlainConvUNet], Type[ResidualEncoderUNet]], - num_input_channels: int, - features_per_stage: Tuple[int], - blocks_per_stage_encoder: Union[int, Tuple[int]], - blocks_per_stage_decoder: Union[int, Tuple[int]], - num_labels: int): + input_channels: int, + output_channels: int, + arch_class_name: str, + arch_kwargs: dict, + arch_kwargs_req_import: Tuple[str, ...]): """ Works for PlainConvUNet, ResidualEncoderUNet """ - dim = len(patch_size) - conv_op = convert_dim_to_conv_op(dim) - norm_op = get_matching_instancenorm(conv_op) - net = UNet_class(num_input_channels, n_stages, - features_per_stage, - conv_op, - 3, - strides, - blocks_per_stage_encoder, - num_labels, - blocks_per_stage_decoder, - norm_op=norm_op) - return net.compute_conv_feature_map_size(patch_size) + a = torch.get_num_threads() + torch.set_num_threads(get_allowed_n_proc_DA()) + # print(f'instantiating network, patch size {patch_size}, pool op: {arch_kwargs["strides"]}') + net = get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, + output_channels, + allow_init=False) + ret = net.compute_conv_feature_map_size(patch_size) + torch.set_num_threads(a) + return ret def determine_resampling(self, *args, **kwargs): """ @@ -228,10 +227,24 @@ def determine_transpose(self): def get_plans_for_configuration(self, spacing: Union[np.ndarray, Tuple[float, ...], List[float]], - median_shape: Union[np.ndarray, Tuple[int, ...], List[int]], + median_shape: Union[np.ndarray, Tuple[int, ...]], data_identifier: str, - approximate_n_voxels_dataset: float) -> dict: + approximate_n_voxels_dataset: float, + _cache: dict) -> dict: + def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: + return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for + i in range(num_stages)]) + + def _keygen(patch_size, strides): + return str(patch_size) + '_' + str(strides) + assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}" + num_input_channels = len(self.dataset_json['channel_names'].keys() + if 'channel_names' in self.dataset_json.keys() + else self.dataset_json['modality'].keys()) + max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d + unet_conv_op = convert_dim_to_conv_op(len(spacing)) + # print(spacing, median_shape, approximate_n_voxels_dataset) # find an initial patch size # we first use the spacing to get an aspect ratio @@ -260,34 +273,56 @@ def get_plans_for_configuration(self, shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size, self.UNet_featuremap_min_edge_length, 999999) + num_stages = len(pool_op_kernel_sizes) + + norm = get_matching_instancenorm(unet_conv_op) + architecture_kwargs = { + 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, + 'arch_kwargs': { + 'n_stages': num_stages, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'n_conv_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + 'conv_bias': True, + 'norm_op': norm.__module__ + '.' + norm.__name__, + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, + 'dropout_op_kwargs': None, + 'nonlin': 'torch.nn.LeakyReLU', + 'nonlin_kwargs': {'inplace': True}, + }, + '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), + } # now estimate vram consumption - num_stages = len(pool_op_kernel_sizes) - estimate = self.static_estimate_VRAM_usage(tuple(patch_size), - num_stages, - tuple([tuple(i) for i in pool_op_kernel_sizes]), - self.UNet_class, - len(self.dataset_json['channel_names'].keys() - if 'channel_names' in self.dataset_json.keys() - else self.dataset_json['modality'].keys()), - tuple([min(self.UNet_max_features_2d if len(patch_size) == 2 else - self.UNet_max_features_3d, - self.UNet_reference_com_nfeatures * 2 ** i) for - i in range(len(pool_op_kernel_sizes))]), - self.UNet_blocks_per_stage_encoder[:num_stages], - self.UNet_blocks_per_stage_decoder[:num_stages - 1], - len(self.dataset_json['labels'].keys())) + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] + else: + estimate = self.static_estimate_VRAM_usage(patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # how large is the reference for us here (batch size etc)? # adapt for our vram target reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \ (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) - while estimate > reference: - # print(patch_size) + ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d + # we enforce a batch size of at least two, reference values may have been computed for different batch sizes. + # Correct for that in the while loop if statement + while (estimate / ref_bs * 2) > reference: + # print(patch_size, estimate, reference) # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the # aspect ratio the most (that is the largest relative to median shape) - axis_to_be_reduced = np.argsort(patch_size / median_shape[:len(spacing)])[-1] + axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1] # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256. @@ -295,6 +330,7 @@ def get_plans_for_configuration(self, # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first # subtract shape_must_be_divisible_by, then recompute it and then subtract the # recomputed shape_must_be_divisible_by. Annoying. + patch_size = list(patch_size) tmp = deepcopy(patch_size) tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] _, _, _, _, shape_must_be_divisible_by = \ @@ -310,30 +346,35 @@ def get_plans_for_configuration(self, 999999) num_stages = len(pool_op_kernel_sizes) - estimate = self.static_estimate_VRAM_usage(tuple(patch_size), - num_stages, - tuple([tuple(i) for i in pool_op_kernel_sizes]), - self.UNet_class, - len(self.dataset_json['channel_names'].keys() - if 'channel_names' in self.dataset_json.keys() - else self.dataset_json['modality'].keys()), - tuple([min(self.UNet_max_features_2d if len(patch_size) == 2 else - self.UNet_max_features_3d, - self.UNet_reference_com_nfeatures * 2 ** i) for - i in range(len(pool_op_kernel_sizes))]), - self.UNet_blocks_per_stage_encoder[:num_stages], - self.UNet_blocks_per_stage_decoder[:num_stages - 1], - len(self.dataset_json['labels'].keys())) + architecture_kwargs['arch_kwargs'].update({ + 'n_stages': num_stages, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'n_conv_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + }) + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] + else: + estimate = self.static_estimate_VRAM_usage( + patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was # executed. If not, additional vram headroom is used to increase batch size - ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d batch_size = round((reference / estimate) * ref_bs) # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. We cannot # go smaller than self.UNet_min_batch_size though bs_corresponding_to_5_percent = round( - approximate_n_voxels_dataset * 0.05 / np.prod(patch_size, dtype=np.float64)) + approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64)) batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() @@ -341,7 +382,7 @@ def get_plans_for_configuration(self, normalization_schemes, mask_is_used_for_norm = \ self.determine_normalization_scheme_and_whether_mask_is_used_for_norm() - num_stages = len(pool_op_kernel_sizes) + plan = { 'data_identifier': data_identifier, 'preprocessor_name': self.preprocessor_name, @@ -351,20 +392,13 @@ def get_plans_for_configuration(self, 'spacing': spacing, 'normalization_schemes': normalization_schemes, 'use_mask_for_norm': mask_is_used_for_norm, - 'UNet_class_name': self.UNet_class.__name__, - 'UNet_base_num_features': self.UNet_base_num_features, - 'n_conv_per_stage_encoder': self.UNet_blocks_per_stage_encoder[:num_stages], - 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], - 'num_pool_per_axis': network_num_pool_per_axis, - 'pool_op_kernel_sizes': pool_op_kernel_sizes, - 'conv_kernel_sizes': conv_kernel_sizes, - 'unet_max_num_features': self.UNet_max_features_3d if len(spacing) == 3 else self.UNet_max_features_2d, 'resampling_fn_data': resampling_data.__name__, 'resampling_fn_seg': resampling_seg.__name__, 'resampling_fn_data_kwargs': resampling_data_kwargs, 'resampling_fn_seg_kwargs': resampling_seg_kwargs, 'resampling_fn_probabilities': resampling_softmax.__name__, 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs, + 'architecture': architecture_kwargs } return plan @@ -379,6 +413,8 @@ def plan_experiment(self): So for now if you want a different transpose_forward/backward you need to create a new planner. Also not too hard. """ + # we use this as a cache to prevent having to instantiate the architecture too often. Saves computation time + _tmp = {} # first get transpose transpose_forward, transpose_backward = self.determine_transpose() @@ -400,7 +436,7 @@ def plan_experiment(self): plan_3d_fullres = self.get_plans_for_configuration(fullres_spacing_transposed, new_median_shape_transposed, self.generate_data_identifier('3d_fullres'), - approximate_n_voxels_dataset) + approximate_n_voxels_dataset, _tmp) # maybe add 3d_lowres as well patch_size_fullres = plan_3d_fullres['patch_size'] median_num_voxels = np.prod(new_median_shape_transposed, dtype=np.float64) @@ -410,7 +446,6 @@ def plan_experiment(self): lowres_spacing = deepcopy(plan_3d_fullres['spacing']) spacing_increase_factor = 1.03 # used to be 1.01 but that is slow with new GPU memory estimation! - while num_voxels_in_patch / median_num_voxels < self.lowres_creation_threshold: # we incrementally increase the target spacing. We start with the anisotropic axis/axes until it/they # is/are similar (factor 2) to the other ax(i/e)s. @@ -423,16 +458,21 @@ def plan_experiment(self): dtype=np.float64) # print(lowres_spacing) plan_3d_lowres = self.get_plans_for_configuration(lowres_spacing, - [round(i) for i in plan_3d_fullres['spacing'] / - lowres_spacing * new_median_shape_transposed], + tuple([round(i) for i in plan_3d_fullres['spacing'] / + lowres_spacing * new_median_shape_transposed]), self.generate_data_identifier('3d_lowres'), float(np.prod(median_num_voxels) * - self.dataset_json['numTraining'])) + self.dataset_json['numTraining']), _tmp) num_voxels_in_patch = np.prod(plan_3d_lowres['patch_size'], dtype=np.int64) print(f'Attempting to find 3d_lowres config. ' f'\nCurrent spacing: {lowres_spacing}. ' f'\nCurrent patch size: {plan_3d_lowres["patch_size"]}. ' f'\nCurrent median shape: {plan_3d_fullres["spacing"] / lowres_spacing * new_median_shape_transposed}') + if np.prod(new_median_shape_transposed, dtype=np.float64) / median_num_voxels < 2: + print(f'Dropping 3d_lowres config because the image size difference to 3d_fullres is too small. ' + f'3d_fullres: {new_median_shape_transposed}, ' + f'3d_lowres: {[round(i) for i in plan_3d_fullres["spacing"] / lowres_spacing * new_median_shape_transposed]}') + plan_3d_lowres = None if plan_3d_lowres is not None: plan_3d_lowres['batch_dice'] = False plan_3d_fullres['batch_dice'] = True @@ -445,7 +485,8 @@ def plan_experiment(self): # 2D configuration plan_2d = self.get_plans_for_configuration(fullres_spacing_transposed[1:], new_median_shape_transposed[1:], - self.generate_data_identifier('2d'), approximate_n_voxels_dataset) + self.generate_data_identifier('2d'), approximate_n_voxels_dataset, + _tmp) plan_2d['batch_dice'] = True print('2D U-Net configuration:') @@ -461,7 +502,7 @@ def plan_experiment(self): shutil.copy(join(self.raw_dataset_folder, 'dataset.json'), join(nnUNet_preprocessed, self.dataset_name, 'dataset.json')) - # json is stupid and I hate it... "Object of type int64 is not JSON serializable" -> my ass + # json is ###. I hate it... "Object of type int64 is not JSON serializable" plans = { 'dataset_name': self.dataset_name, 'plans_name': self.plans_identifier, @@ -530,5 +571,23 @@ def load_plans(self, fname: str): self.plans = load_json(fname) +def _maybe_copy_splits_file(splits_file: str, target_fname: str): + if not isfile(target_fname): + shutil.copy(splits_file, target_fname) + else: + # split already exists, do not copy, but check that the splits match. + # This code allows target_fname to contain more splits than splits_file. This is OK. + splits_source = load_json(splits_file) + splits_target = load_json(target_fname) + # all folds in the source file must match the target file + for i in range(len(splits_source)): + train_source = set(splits_source[i]['train']) + train_target = set(splits_target[i]['train']) + assert train_target == train_source + val_source = set(splits_source[i]['val']) + val_target = set(splits_target[i]['val']) + assert val_source == val_target + + if __name__ == '__main__': ExperimentPlanner(2, 8).plan_experiment() diff --git a/nnunetv2/experiment_planning/experiment_planners/network_topology.py b/nnunetv2/experiment_planning/experiment_planners/network_topology.py index 1ce6a46..6922f7b 100644 --- a/nnunetv2/experiment_planning/experiment_planners/network_topology.py +++ b/nnunetv2/experiment_planning/experiment_planners/network_topology.py @@ -100,6 +100,9 @@ def get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpo must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis) patch_size = pad_shape(patch_size, must_be_divisible_by) + def _to_tuple(lst): + return tuple(_to_tuple(i) if isinstance(i, list) else i for i in lst) + # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here conv_kernel_sizes.append([3]*dim) - return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by + return num_pool_per_axis, _to_tuple(pool_op_kernel_sizes), _to_tuple(conv_kernel_sizes), tuple(patch_size), must_be_divisible_by diff --git a/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py index 52ca938..0ed9532 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py @@ -1,9 +1,14 @@ +import numpy as np +from copy import deepcopy from typing import Union, List, Tuple +from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm from torch import nn from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner -from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet + +from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props class ResEncUNetPlanner(ExperimentPlanner): @@ -14,23 +19,200 @@ def __init__(self, dataset_name_or_id: Union[str, int], suppress_transpose: bool = False): super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, overwrite_target_spacing, suppress_transpose) - - self.UNet_base_num_features = 32 self.UNet_class = ResidualEncoderUNet # the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as # much as possible self.UNet_reference_val_3d = 680000000 self.UNet_reference_val_2d = 135000000 - self.UNet_reference_com_nfeatures = 32 - self.UNet_reference_val_corresp_GB = 8 - self.UNet_reference_val_corresp_bs_2d = 12 - self.UNet_reference_val_corresp_bs_3d = 2 - self.UNet_featuremap_min_edge_length = 4 self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6) self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) - self.UNet_min_batch_size = 2 - self.UNet_max_features_2d = 512 - self.UNet_max_features_3d = 320 + + def generate_data_identifier(self, configuration_name: str) -> str: + """ + configurations are unique within each plans file but different plans file can have configurations with the + same name. In order to distinguish the associated data we need a data identifier that reflects not just the + config but also the plans it originates from + """ + if configuration_name == '2d' or configuration_name == '3d_fullres': + # we do not deviate from ExperimentPlanner so we can reuse its data + return 'nnUNetPlans' + '_' + configuration_name + else: + return self.plans_identifier + '_' + configuration_name + + def get_plans_for_configuration(self, + spacing: Union[np.ndarray, Tuple[float, ...], List[float]], + median_shape: Union[np.ndarray, Tuple[int, ...]], + data_identifier: str, + approximate_n_voxels_dataset: float, + _cache: dict) -> dict: + def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: + return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for + i in range(num_stages)]) + + def _keygen(patch_size, strides): + return str(patch_size) + '_' + str(strides) + + assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}" + num_input_channels = len(self.dataset_json['channel_names'].keys() + if 'channel_names' in self.dataset_json.keys() + else self.dataset_json['modality'].keys()) + max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d + unet_conv_op = convert_dim_to_conv_op(len(spacing)) + + # print(spacing, median_shape, approximate_n_voxels_dataset) + # find an initial patch size + # we first use the spacing to get an aspect ratio + tmp = 1 / np.array(spacing) + + # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same + # volume as a patch of size 256 ** 3) + # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be + # ideal because large initial patch sizes increase computation time because more iterations in the while loop + # further down may be required. + if len(spacing) == 3: + initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)] + elif len(spacing) == 2: + initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)] + else: + raise RuntimeError() + + # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that + # this is different from how nnU-Net v1 does it! + # todo patch size can still get too large because we pad the patch size to a multiple of 2**n + initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])]) + + # use that to get the network topology. Note that this changes the patch_size depending on the number of + # pooling operations (must be divisible by 2**num_pool in each axis) + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + num_stages = len(pool_op_kernel_sizes) + + norm = get_matching_instancenorm(unet_conv_op) + architecture_kwargs = { + 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, + 'arch_kwargs': { + 'n_stages': num_stages, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + 'conv_bias': True, + 'norm_op': norm.__module__ + '.' + norm.__name__, + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, + 'dropout_op_kwargs': None, + 'nonlin': 'torch.nn.LeakyReLU', + 'nonlin_kwargs': {'inplace': True}, + }, + '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), + } + + # now estimate vram consumption + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] + else: + estimate = self.static_estimate_VRAM_usage(patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + + # how large is the reference for us here (batch size etc)? + # adapt for our vram target + reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \ + (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) + + while estimate > reference: + # print(patch_size) + # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the + # aspect ratio the most (that is the largest relative to median shape) + axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1] + + # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this + # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256. + # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size + # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first + # subtract shape_must_be_divisible_by, then recompute it and then subtract the + # recomputed shape_must_be_divisible_by. Annoying. + patch_size = list(patch_size) + tmp = deepcopy(patch_size) + tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + _, _, _, _, shape_must_be_divisible_by = \ + get_pool_and_conv_props(spacing, tmp, + self.UNet_featuremap_min_edge_length, + 999999) + patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + + # now recompute topology + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + + num_stages = len(pool_op_kernel_sizes) + architecture_kwargs['arch_kwargs'].update({ + 'n_stages': num_stages, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + }) + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] + else: + estimate = self.static_estimate_VRAM_usage( + patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + + # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was + # executed. If not, additional vram headroom is used to increase batch size + ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d + batch_size = round((reference / estimate) * ref_bs) + + # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. We cannot + # go smaller than self.UNet_min_batch_size though + bs_corresponding_to_5_percent = round( + approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64)) + batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) + + resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() + resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn() + + normalization_schemes, mask_is_used_for_norm = \ + self.determine_normalization_scheme_and_whether_mask_is_used_for_norm() + + plan = { + 'data_identifier': data_identifier, + 'preprocessor_name': self.preprocessor_name, + 'batch_size': batch_size, + 'patch_size': patch_size, + 'median_image_size_in_voxels': median_shape, + 'spacing': spacing, + 'normalization_schemes': normalization_schemes, + 'use_mask_for_norm': mask_is_used_for_norm, + 'resampling_fn_data': resampling_data.__name__, + 'resampling_fn_seg': resampling_seg.__name__, + 'resampling_fn_data_kwargs': resampling_data_kwargs, + 'resampling_fn_seg_kwargs': resampling_seg_kwargs, + 'resampling_fn_probabilities': resampling_softmax.__name__, + 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs, + 'architecture': architecture_kwargs + } + return plan if __name__ == '__main__': @@ -51,4 +233,3 @@ def __init__(self, dataset_name_or_id: Union[str, int], conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None, nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) print(net.compute_conv_feature_map_size((512, 512))) # -> 129793792 - diff --git a/nnunetv2/experiment_planning/plan_and_preprocess_api.py b/nnunetv2/experiment_planning/plan_and_preprocess_api.py index eb94840..c81e06a 100644 --- a/nnunetv2/experiment_planning/plan_and_preprocess_api.py +++ b/nnunetv2/experiment_planning/plan_and_preprocess_api.py @@ -1,17 +1,16 @@ -import shutil from typing import List, Type, Optional, Tuple, Union -import nnunetv2 -from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, subfiles, load_json +from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, load_json +import nnunetv2 +from nnunetv2.configuration import default_num_processes from nnunetv2.experiment_planning.dataset_fingerprint.fingerprint_extractor import DatasetFingerprintExtractor from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner from nnunetv2.experiment_planning.verify_dataset_integrity import verify_dataset_integrity from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed -from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name, maybe_convert_to_dataset_name +from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name from nnunetv2.utilities.find_class_by_name import recursive_find_python_class from nnunetv2.utilities.plans_handling.plans_handler import PlansManager -from nnunetv2.configuration import default_num_processes from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets @@ -52,21 +51,24 @@ def plan_experiment_dataset(dataset_id: int, experiment_planner_class: Type[ExperimentPlanner] = ExperimentPlanner, gpu_memory_target_in_gb: float = 8, preprocess_class_name: str = 'DefaultPreprocessor', overwrite_target_spacing: Optional[Tuple[float, ...]] = None, - overwrite_plans_name: Optional[str] = None) -> dict: + overwrite_plans_name: Optional[str] = None) -> Tuple[dict, str]: """ overwrite_target_spacing ONLY applies to 3d_fullres and 3d_cascade fullres! """ kwargs = {} if overwrite_plans_name is not None: kwargs['plans_name'] = overwrite_plans_name - return experiment_planner_class(dataset_id, - gpu_memory_target_in_gb=gpu_memory_target_in_gb, - preprocessor_name=preprocess_class_name, - overwrite_target_spacing=[float(i) for i in overwrite_target_spacing] if - overwrite_target_spacing is not None else overwrite_target_spacing, - suppress_transpose=False, # might expose this later, - **kwargs - ).plan_experiment() + + planner = experiment_planner_class(dataset_id, + gpu_memory_target_in_gb=gpu_memory_target_in_gb, + preprocessor_name=preprocess_class_name, + overwrite_target_spacing=[float(i) for i in overwrite_target_spacing] if + overwrite_target_spacing is not None else overwrite_target_spacing, + suppress_transpose=False, # might expose this later, + **kwargs + ) + ret = planner.plan_experiment() + return ret, planner.plans_identifier def plan_experiments(dataset_ids: List[int], experiment_planner_class_name: str = 'ExperimentPlanner', @@ -79,9 +81,12 @@ def plan_experiments(dataset_ids: List[int], experiment_planner_class_name: str experiment_planner = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"), experiment_planner_class_name, current_module="nnunetv2.experiment_planning") + plans_identifier = None for d in dataset_ids: - plan_experiment_dataset(d, experiment_planner, gpu_memory_target_in_gb, preprocess_class_name, - overwrite_target_spacing, overwrite_plans_name) + _, plans_identifier = plan_experiment_dataset(d, experiment_planner, gpu_memory_target_in_gb, + preprocess_class_name, + overwrite_target_spacing, overwrite_plans_name) + return plans_identifier def preprocess_dataset(dataset_id: int, @@ -128,7 +133,6 @@ def preprocess_dataset(dataset_id: int, update=True) - def preprocess(dataset_ids: List[int], plans_identifier: str = 'nnUNetPlans', configurations: Union[Tuple[str], List[str]] = ('2d', '3d_fullres', '3d_lowres'), diff --git a/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py b/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py index 556f04a..88a37f0 100644 --- a/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py +++ b/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py @@ -149,7 +149,7 @@ def plan_and_preprocess_entry(): 'know what you are doing and NEVER use this without running the default nnU-Net first ' '(as a baseline). Changing the target spacing for the other configurations is currently ' 'not implemented. New target spacing must be a list of three numbers!') - parser.add_argument('-overwrite_plans_name', default='nnUNetPlans', required=False, + parser.add_argument('-overwrite_plans_name', default=None, required=False, help='[OPTIONAL] uSE A CUSTOM PLANS IDENTIFIER. If you used -gpu_memory_target, ' '-preprocessor_name or ' '-overwrite_target_spacing it is best practice to use -overwrite_plans_name to generate a ' @@ -183,7 +183,7 @@ def plan_and_preprocess_entry(): # experiment planning print('Experiment planning...') - plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name, args.overwrite_target_spacing, args.overwrite_plans_name) + plans_identifier = plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name, args.overwrite_target_spacing, args.overwrite_plans_name) # manage default np if args.np is None: @@ -194,7 +194,7 @@ def plan_and_preprocess_entry(): # preprocessing if not args.no_pp: print('Preprocessing...') - preprocess(args.d, args.overwrite_plans_name, args.c, np, args.verbose) + preprocess(args.d, plans_identifier, args.c, np, args.verbose) if __name__ == '__main__': diff --git a/nnunetv2/experiment_planning/verify_dataset_integrity.py b/nnunetv2/experiment_planning/verify_dataset_integrity.py index 61175d0..71f84bf 100644 --- a/nnunetv2/experiment_planning/verify_dataset_integrity.py +++ b/nnunetv2/experiment_planning/verify_dataset_integrity.py @@ -76,7 +76,7 @@ def check_cases(image_files: List[str], label_file: str, expected_num_channels: if not np.allclose(spacing_seg, spacing_images): print('Error: Spacing mismatch between segmentation and corresponding images. \nSpacing images: %s. ' '\nSpacing seg: %s. \nImage files: %s. \nSeg file: %s\n' % - (shape_image, shape_seg, image_files, label_file)) + (spacing_images, spacing_seg, image_files, label_file)) ret = False # check modalities diff --git a/nnunetv2/imageio/natural_image_reager_writer.py b/nnunetv2/imageio/natural_image_reader_writer.py similarity index 100% rename from nnunetv2/imageio/natural_image_reager_writer.py rename to nnunetv2/imageio/natural_image_reader_writer.py diff --git a/nnunetv2/imageio/reader_writer_registry.py b/nnunetv2/imageio/reader_writer_registry.py index e2921e6..606334c 100644 --- a/nnunetv2/imageio/reader_writer_registry.py +++ b/nnunetv2/imageio/reader_writer_registry.py @@ -4,7 +4,7 @@ from batchgenerators.utilities.file_and_folder_operations import join import nnunetv2 -from nnunetv2.imageio.natural_image_reager_writer import NaturalImage2DIO +from nnunetv2.imageio.natural_image_reader_writer import NaturalImage2DIO from nnunetv2.imageio.nibabel_reader_writer import NibabelIO, NibabelIOWithReorient from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO from nnunetv2.imageio.tif_reader_writer import Tiff3DIO diff --git a/nnunetv2/inference/data_iterators.py b/nnunetv2/inference/data_iterators.py index 9dfee4e..1777fb9 100644 --- a/nnunetv2/inference/data_iterators.py +++ b/nnunetv2/inference/data_iterators.py @@ -53,6 +53,7 @@ def preprocess_fromfiles_save_to_queue(list_of_lists: List[List[str]], pass done_event.set() except Exception as e: + # print(Exception, e) abort_event.set() raise e @@ -99,6 +100,7 @@ def preprocessing_iterator_fromfiles(list_of_lists: List[List[str]], worker_ctr = 0 while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()): + # import IPython;IPython.embed() if not target_queues[worker_ctr].empty(): item = target_queues[worker_ctr].get() worker_ctr = (worker_ctr + 1) % num_processes diff --git a/nnunetv2/inference/examples.py b/nnunetv2/inference/examples.py index b57a398..a66d98f 100644 --- a/nnunetv2/inference/examples.py +++ b/nnunetv2/inference/examples.py @@ -12,7 +12,7 @@ tile_step_size=0.5, use_gaussian=True, use_mirroring=True, - perform_everything_on_gpu=True, + perform_everything_on_device=True, device=torch.device('cuda', 0), verbose=False, verbose_preprocessing=False, diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py index 1fe69ea..8015ed7 100644 --- a/nnunetv2/inference/predict_from_raw_data.py +++ b/nnunetv2/inference/predict_from_raw_data.py @@ -1,7 +1,7 @@ import inspect +import itertools import multiprocessing import os -import traceback from copy import deepcopy from time import sleep from typing import Tuple, Union, List, Optional @@ -39,7 +39,7 @@ def __init__(self, tile_step_size: float = 0.5, use_gaussian: bool = True, use_mirroring: bool = True, - perform_everything_on_gpu: bool = True, + perform_everything_on_device: bool = True, device: torch.device = torch.device('cuda'), verbose: bool = False, verbose_preprocessing: bool = False, @@ -56,13 +56,12 @@ def __init__(self, self.use_mirroring = use_mirroring if device.type == 'cuda': # device = torch.device(type='cuda', index=0) # set the desired GPU with CUDA_VISIBLE_DEVICES! - # why would I ever want to do that. Stupid dobby. This kills DDP inference... pass if device.type != 'cuda': - print(f'perform_everything_on_gpu=True is only supported for cuda devices! Setting this to False') - perform_everything_on_gpu = False + print(f'perform_everything_on_device=True is only supported for cuda devices! Setting this to False') + perform_everything_on_device = False self.device = device - self.perform_everything_on_gpu = perform_everything_on_gpu + self.perform_everything_on_device = perform_everything_on_device def initialize_from_trained_model_folder(self, model_training_output_dir: str, use_folds: Union[Tuple[Union[int, str]], None], @@ -98,8 +97,16 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str, num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json) trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), trainer_name, 'nnunetv2.training.nnUNetTrainer') - network = trainer_class.build_network_architecture(plans_manager, dataset_json, configuration_manager, - num_input_channels, enable_deep_supervision=False) + + network = trainer_class.build_network_architecture( + configuration_manager.network_arch_class_name, + configuration_manager.network_arch_init_kwargs, + configuration_manager.network_arch_init_kwargs_req_import, + num_input_channels, + plans_manager.get_label_manager(dataset_json).num_segmentation_heads, + enable_deep_supervision=False + ) + self.plans_manager = plans_manager self.configuration_manager = configuration_manager self.list_of_parameters = parameters @@ -110,7 +117,7 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str, self.label_manager = plans_manager.get_label_manager(dataset_json) if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \ and not isinstance(self.network, OptimizedModule): - print('compiling network') + print('Using torch.compile') self.network = torch.compile(self.network) def manual_initialization(self, network: nn.Module, plans_manager: PlansManager, @@ -129,12 +136,13 @@ def manual_initialization(self, network: nn.Module, plans_manager: PlansManager, self.allowed_mirroring_axes = inference_allowed_mirroring_axes self.label_manager = plans_manager.get_label_manager(dataset_json) allow_compile = True - allow_compile = allow_compile and ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) + allow_compile = allow_compile and ('nnUNet_compile' in os.environ.keys()) and ( + os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) allow_compile = allow_compile and not isinstance(self.network, OptimizedModule) if isinstance(self.network, DistributedDataParallel): allow_compile = allow_compile and isinstance(self.network.module, OptimizedModule) if allow_compile: - print('compiling network') + print('Using torch.compile') self.network = torch.compile(self.network) @staticmethod @@ -352,7 +360,7 @@ def predict_from_data_iterator(self, else: print(f'\nPredicting image of shape {data.shape}:') - print(f'perform_everything_on_gpu: {self.perform_everything_on_gpu}') + print(f'perform_everything_on_device: {self.perform_everything_on_device}') properties = preprocessed['data_properties'] @@ -360,7 +368,6 @@ def predict_from_data_iterator(self, # npy files proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2) while not proceed: - # print('sleeping') sleep(0.1) proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2) @@ -368,8 +375,8 @@ def predict_from_data_iterator(self, if ofile is not None: # this needs to go into background processes - # export_prediction_from_logits(prediction, properties, configuration_manager, plans_manager, - # dataset_json, ofile, save_probabilities) + # export_prediction_from_logits(prediction, properties, self.configuration_manager, self.plans_manager, + # self.dataset_json, ofile, save_probabilities) print('sending off prediction to background worker for resampling and export') r.append( export_pool.starmap_async( @@ -379,10 +386,12 @@ def predict_from_data_iterator(self, ) ) else: - # convert_predicted_logits_to_segmentation_with_correct_shape(prediction, plans_manager, - # configuration_manager, label_manager, - # properties, - # save_probabilities) + # convert_predicted_logits_to_segmentation_with_correct_shape( + # prediction, self.plans_manager, + # self.configuration_manager, self.label_manager, + # properties, + # save_probabilities) + print('sending off prediction to background worker for resampling') r.append( export_pool.starmap_async( @@ -453,56 +462,33 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Ten RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE. SEE convert_predicted_logits_to_segmentation_with_correct_shape """ - # we have some code duplication here but this allows us to run with perform_everything_on_gpu=True as - # default and not have the entire program crash in case of GPU out of memory. Neat. That should make - # things a lot faster for some datasets. - original_perform_everything_on_gpu = self.perform_everything_on_gpu + n_threads = torch.get_num_threads() + torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads) with torch.no_grad(): prediction = None - if self.perform_everything_on_gpu: - try: - for params in self.list_of_parameters: - - # messing with state dict names... - if not isinstance(self.network, OptimizedModule): - self.network.load_state_dict(params) - else: - self.network._orig_mod.load_state_dict(params) - - if prediction is None: - prediction = self.predict_sliding_window_return_logits(data) - else: - prediction += self.predict_sliding_window_return_logits(data) - - if len(self.list_of_parameters) > 1: - prediction /= len(self.list_of_parameters) - - except RuntimeError: - print('Prediction with perform_everything_on_gpu=True failed due to insufficient GPU memory. ' - 'Falling back to perform_everything_on_gpu=False. Not a big deal, just slower...') - print('Error:') - traceback.print_exc() - prediction = None - self.perform_everything_on_gpu = False - - if prediction is None: - for params in self.list_of_parameters: - # messing with state dict names... - if not isinstance(self.network, OptimizedModule): - self.network.load_state_dict(params) - else: - self.network._orig_mod.load_state_dict(params) - - if prediction is None: - prediction = self.predict_sliding_window_return_logits(data) - else: - prediction += self.predict_sliding_window_return_logits(data) - if len(self.list_of_parameters) > 1: - prediction /= len(self.list_of_parameters) - - print('Prediction done, transferring to CPU if needed') + + for params in self.list_of_parameters: + + # messing with state dict names... + if not isinstance(self.network, OptimizedModule): + self.network.load_state_dict(params) + else: + self.network._orig_mod.load_state_dict(params) + + # why not leave prediction on device if perform_everything_on_device? Because this may cause the + # second iteration to crash due to OOM. Grabbing that with try except cause way more bloated code than + # this actually saves computation time + if prediction is None: + prediction = self.predict_sliding_window_return_logits(data).to('cpu') + else: + prediction += self.predict_sliding_window_return_logits(data).to('cpu') + + if len(self.list_of_parameters) > 1: + prediction /= len(self.list_of_parameters) + + if self.verbose: print('Prediction done') prediction = prediction.to('cpu') - self.perform_everything_on_gpu = original_perform_everything_on_gpu + torch.set_num_threads(n_threads) return prediction def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]): @@ -548,24 +534,66 @@ def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor: # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3 assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!' - num_predictons = 2 ** len(mirror_axes) - if 0 in mirror_axes: - prediction += torch.flip(self.network(torch.flip(x, (2,))), (2,)) - if 1 in mirror_axes: - prediction += torch.flip(self.network(torch.flip(x, (3,))), (3,)) - if 2 in mirror_axes: - prediction += torch.flip(self.network(torch.flip(x, (4,))), (4,)) - if 0 in mirror_axes and 1 in mirror_axes: - prediction += torch.flip(self.network(torch.flip(x, (2, 3))), (2, 3)) - if 0 in mirror_axes and 2 in mirror_axes: - prediction += torch.flip(self.network(torch.flip(x, (2, 4))), (2, 4)) - if 1 in mirror_axes and 2 in mirror_axes: - prediction += torch.flip(self.network(torch.flip(x, (3, 4))), (3, 4)) - if 0 in mirror_axes and 1 in mirror_axes and 2 in mirror_axes: - prediction += torch.flip(self.network(torch.flip(x, (2, 3, 4))), (2, 3, 4)) - prediction /= num_predictons + axes_combinations = [ + c for i in range(len(mirror_axes)) for c in itertools.combinations([m + 2 for m in mirror_axes], i + 1) + ] + for axes in axes_combinations: + prediction += torch.flip(self.network(torch.flip(x, (*axes,))), (*axes,)) + prediction /= (len(axes_combinations) + 1) return prediction + def _internal_predict_sliding_window_return_logits(self, + data: torch.Tensor, + slicers, + do_on_device: bool = True, + ): + predicted_logits = n_predictions = prediction = gaussian = workon = None + results_device = self.device if do_on_device else torch.device('cpu') + + try: + empty_cache(self.device) + + # move data to device + if self.verbose: + print(f'move image to device {results_device}') + data = data.to(results_device) + + # preallocate arrays + if self.verbose: + print(f'preallocating results arrays on device {results_device}') + predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]), + dtype=torch.half, + device=results_device) + n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device) + if self.use_gaussian: + gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8, + value_scaling_factor=10, + device=results_device) + + if self.verbose: print('running prediction') + if not self.allow_tqdm and self.verbose: print(f'{len(slicers)} steps') + for sl in tqdm(slicers, disable=not self.allow_tqdm): + workon = data[sl][None] + workon = workon.to(self.device, non_blocking=False) + + prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device) + + predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction) + n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1) + + predicted_logits /= n_predictions + # check for infs + if torch.any(torch.isinf(predicted_logits)): + raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, ' + 'reduce value_scaling_factor in compute_gaussian or increase the dtype of ' + 'predicted_logits to fp32') + except Exception as e: + del predicted_logits, n_predictions, prediction, gaussian, workon + empty_cache(self.device) + empty_cache(results_device) + raise e + return predicted_logits + def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \ -> Union[np.ndarray, torch.Tensor]: assert isinstance(input_image, torch.Tensor) @@ -574,7 +602,7 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \ empty_cache(self.device) - # Autocast is a little bitch. + # Autocast can be annoying # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection) # and needs to be disabled. # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False @@ -595,49 +623,24 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \ slicers = self._internal_get_sliding_window_slicers(data.shape[1:]) - # preallocate results and num_predictions - results_device = self.device if self.perform_everything_on_gpu else torch.device('cpu') - if self.verbose: print('preallocating arrays') - try: - data = data.to(self.device) - predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]), - dtype=torch.half, - device=results_device) - n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, - device=results_device) - if self.use_gaussian: - gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8, - value_scaling_factor=10, - device=results_device) - except RuntimeError: - # sometimes the stuff is too large for GPUs. In that case fall back to CPU - results_device = torch.device('cpu') - data = data.to(results_device) - predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]), - dtype=torch.half, - device=results_device) - n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, - device=results_device) - if self.use_gaussian: - gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8, - value_scaling_factor=10, - device=results_device) - finally: - empty_cache(self.device) - - if self.verbose: print('running prediction') - for sl in tqdm(slicers, disable=not self.allow_tqdm): - workon = data[sl][None] - workon = workon.to(self.device, non_blocking=False) - - prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device) - - predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction) - n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1) - - predicted_logits /= n_predictions - empty_cache(self.device) - return predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])] + if self.perform_everything_on_device and self.device != 'cpu': + # we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device + try: + predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, + self.perform_everything_on_device) + except RuntimeError: + print( + 'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU') + empty_cache(self.device) + predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False) + else: + predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, + self.perform_everything_on_device) + + empty_cache(self.device) + # revert padding + predicted_logits = predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])] + return predicted_logits def predict_entry_point_modelfolder(): @@ -685,6 +688,9 @@ def predict_entry_point_modelfolder(): help="Use this to set the device the inference should run with. Available options are 'cuda' " "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!") + parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False, + help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive ' + 'jobs)') print( "\n#######################################################################\nPlease cite the following paper " @@ -717,9 +723,11 @@ def predict_entry_point_modelfolder(): predictor = nnUNetPredictor(tile_step_size=args.step_size, use_gaussian=True, use_mirroring=not args.disable_tta, - perform_everything_on_gpu=True, + perform_everything_on_device=True, device=device, - verbose=args.verbose) + verbose=args.verbose, + allow_tqdm=not args.disable_progress_bar, + verbose_preprocessing=args.verbose) predictor.initialize_from_trained_model_folder(args.m, args.f, args.chk) predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities, overwrite=not args.continue_prediction, @@ -789,6 +797,9 @@ def predict_entry_point(): help="Use this to set the device the inference should run with. Available options are 'cuda' " "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!") + parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False, + help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive ' + 'jobs)') print( "\n#######################################################################\nPlease cite the following paper " @@ -826,10 +837,11 @@ def predict_entry_point(): predictor = nnUNetPredictor(tile_step_size=args.step_size, use_gaussian=True, use_mirroring=not args.disable_tta, - perform_everything_on_gpu=True, + perform_everything_on_device=True, device=device, verbose=args.verbose, - verbose_preprocessing=False) + verbose_preprocessing=args.verbose, + allow_tqdm=not args.disable_progress_bar) predictor.initialize_from_trained_model_folder( model_folder, args.f, @@ -849,7 +861,7 @@ def predict_entry_point(): # args.step_size, # use_gaussian=True, # use_mirroring=not args.disable_tta, - # perform_everything_on_gpu=True, + # perform_everything_on_device=True, # verbose=args.verbose, # save_probabilities=args.save_probabilities, # overwrite=not args.continue_prediction, @@ -865,19 +877,20 @@ def predict_entry_point(): if __name__ == '__main__': # predict a bunch of files from nnunetv2.paths import nnUNet_results, nnUNet_raw + predictor = nnUNetPredictor( tile_step_size=0.5, use_gaussian=True, use_mirroring=True, - perform_everything_on_gpu=True, + perform_everything_on_device=True, device=torch.device('cuda', 0), verbose=False, verbose_preprocessing=False, allow_tqdm=True - ) + ) predictor.initialize_from_trained_model_folder( join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'), - use_folds=(0, ), + use_folds=(0,), checkpoint_name='checkpoint_final.pth', ) predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'), @@ -888,18 +901,18 @@ def predict_entry_point(): # predict a numpy array from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO + img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')]) ret = predictor.predict_single_npy_array(img, props, None, None, False) iterator = predictor.get_data_iterator_from_raw_npy_data([img], None, [props], None, 1) ret = predictor.predict_from_data_iterator(iterator, False, 1) - # predictor = nnUNetPredictor( # tile_step_size=0.5, # use_gaussian=True, # use_mirroring=True, - # perform_everything_on_gpu=True, + # perform_everything_on_device=True, # device=torch.device('cuda', 0), # verbose=False, # allow_tqdm=True @@ -915,4 +928,3 @@ def predict_entry_point(): # num_processes_preprocessing=2, num_processes_segmentation_export=2, # folder_with_segs_from_prev_stage='/media/isensee/data/nnUNet_raw/Dataset003_Liver/imagesTs_predlowres', # num_parts=1, part_id=0) - diff --git a/nnunetv2/inference/readme.md b/nnunetv2/inference/readme.md index 7219528..4f832a1 100644 --- a/nnunetv2/inference/readme.md +++ b/nnunetv2/inference/readme.md @@ -57,7 +57,7 @@ Example: tile_step_size=0.5, use_gaussian=True, use_mirroring=True, - perform_everything_on_gpu=True, + perform_everything_on_device=True, device=torch.device('cuda', 0), verbose=False, verbose_preprocessing=False, diff --git a/nnunetv2/postprocessing/remove_connected_components.py b/nnunetv2/postprocessing/remove_connected_components.py index df29932..a46e8d8 100644 --- a/nnunetv2/postprocessing/remove_connected_components.py +++ b/nnunetv2/postprocessing/remove_connected_components.py @@ -229,12 +229,12 @@ def determine_postprocessing(folder_predictions: str, 'postprocessing_fns': [i.__name__ for i in pp_fns], 'postprocessing_kwargs': pp_fn_kwargs, } - # json is a very annoying little bi###. Can't handle tuples as dict keys. + # json is very annoying. Can't handle tuples as dict keys. tmp['input_folder']['mean'] = {label_or_region_to_key(k): tmp['input_folder']['mean'][k] for k in tmp['input_folder']['mean'].keys()} tmp['postprocessed']['mean'] = {label_or_region_to_key(k): tmp['postprocessed']['mean'][k] for k in tmp['postprocessed']['mean'].keys()} - # did I already say that I hate json? "TypeError: Object of type int64 is not JSON serializable" You retarded bro? + # did I already say that I hate json? "TypeError: Object of type int64 is not JSON serializable" recursive_fix_for_json_export(tmp) save_json(tmp, join(folder_predictions, 'postprocessing.json')) diff --git a/nnunetv2/preprocessing/normalization/default_normalization_schemes.py b/nnunetv2/preprocessing/normalization/default_normalization_schemes.py index 3c90a91..705d477 100644 --- a/nnunetv2/preprocessing/normalization/default_normalization_schemes.py +++ b/nnunetv2/preprocessing/normalization/default_normalization_schemes.py @@ -32,7 +32,7 @@ def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: here seg is used to store the zero valued region. The value for that region in the segmentation is -1 by default. """ - image = image.astype(self.target_dtype) + image = image.astype(self.target_dtype, copy=False) if self.use_mask_for_norm is not None and self.use_mask_for_norm: # negative values in the segmentation encode the 'outside' region (think zero values around the brain as # in BraTS). We want to run the normalization only in the brain region, so we need to mask the image. @@ -45,7 +45,8 @@ def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: else: mean = image.mean() std = image.std() - image = (image - mean) / (max(std, 1e-8)) + image -= mean + image /= (max(std, 1e-8)) return image @@ -54,13 +55,15 @@ class CTNormalization(ImageNormalization): def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: assert self.intensityproperties is not None, "CTNormalization requires intensity properties" - image = image.astype(self.target_dtype) mean_intensity = self.intensityproperties['mean'] std_intensity = self.intensityproperties['std'] lower_bound = self.intensityproperties['percentile_00_5'] upper_bound = self.intensityproperties['percentile_99_5'] - image = np.clip(image, lower_bound, upper_bound) - image = (image - mean_intensity) / max(std_intensity, 1e-8) + + image = image.astype(self.target_dtype, copy=False) + np.clip(image, lower_bound, upper_bound, out=image) + image -= mean_intensity + image /= max(std_intensity, 1e-8) return image @@ -68,16 +71,16 @@ class NoNormalization(ImageNormalization): leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: - return image.astype(self.target_dtype) + return image.astype(self.target_dtype, copy=False) class RescaleTo01Normalization(ImageNormalization): leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: - image = image.astype(self.target_dtype) - image = image - image.min() - image = image / np.clip(image.max(), a_min=1e-8, a_max=None) + image = image.astype(self.target_dtype, copy=False) + image -= image.min() + image /= np.clip(image.max(), a_min=1e-8, a_max=None) return image @@ -89,7 +92,7 @@ def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: "Your images do not seem to be RGB images" assert image.max() <= 255, "RGB images are uint 8, for whatever reason I found pixel values greater than 255" \ ". Your images do not seem to be RGB images" - image = image.astype(self.target_dtype) - image = image / 255. + image = image.astype(self.target_dtype, copy=False) + image /= 255. return image diff --git a/nnunetv2/preprocessing/resampling/default_resampling.py b/nnunetv2/preprocessing/resampling/default_resampling.py index e83f614..e23e14d 100644 --- a/nnunetv2/preprocessing/resampling/default_resampling.py +++ b/nnunetv2/preprocessing/resampling/default_resampling.py @@ -83,7 +83,7 @@ def resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray], force_separate_z: Union[bool, None] = False, separate_z_anisotropy_threshold: float = ANISO_THRESHOLD): """ - needed for segmentation export. Stupid, I know. Maybe we can fix that with Leos new resampling functions + needed for segmentation export. Stupid, I know """ if isinstance(data, torch.Tensor): data = data.cpu().numpy() diff --git a/nnunetv2/training/dataloading/utils.py b/nnunetv2/training/dataloading/utils.py index bd145b4..352d182 100644 --- a/nnunetv2/training/dataloading/utils.py +++ b/nnunetv2/training/dataloading/utils.py @@ -1,13 +1,93 @@ +from __future__ import annotations import multiprocessing import os -from multiprocessing import Pool from typing import List +from pathlib import Path +from warnings import warn import numpy as np from batchgenerators.utilities.file_and_folder_operations import isfile, subfiles from nnunetv2.configuration import default_num_processes +def find_broken_image_and_labels( + path_to_data_dir: str | Path, +) -> tuple[set[str], set[str]]: + """ + Iterates through all numpys and tries to read them once to see if a ValueError is raised. + If so, the case id is added to the respective set and returned for potential fixing. + + :path_to_data_dir: Path/str to the preprocessed directory containing the npys and npzs. + :returns: Tuple of a set containing the case ids of the broken npy images and a set of the case ids of broken npy segmentations. + """ + content = os.listdir(path_to_data_dir) + unique_ids = [c[:-4] for c in content if c.endswith(".npz")] + failed_data_ids = set() + failed_seg_ids = set() + for unique_id in unique_ids: + # Try reading data + try: + np.load(path_to_data_dir / (unique_id + ".npy"), "r") + except ValueError: + failed_data_ids.add(unique_id) + # Try reading seg + try: + np.load(path_to_data_dir / (unique_id + "_seg.npy"), "r") + except ValueError: + failed_seg_ids.add(unique_id) + + return failed_data_ids, failed_seg_ids + + +def try_fix_broken_npy(path_do_data_dir: Path, case_ids: set[str], fix_image: bool): + """ + Receives broken case ids and tries to fix them by re-extracting the npz file (up to 5 times). + + :param case_ids: Set of case ids that are broken. + :param path_do_data_dir: Path to the preprocessed directory containing the npys and npzs. + :raises ValueError: If the npy file could not be unpacked after 5 tries. -- + """ + for case_id in case_ids: + for i in range(5): + try: + key = "data" if fix_image else "seg" + suffix = ".npy" if fix_image else "_seg.npy" + read_npz = np.load(path_do_data_dir / (case_id + ".npz"), "r")[key] + np.save(path_do_data_dir / (case_id + suffix), read_npz) + # Try loading the just saved image. + np.load(path_do_data_dir / (case_id + suffix), "r") + break + except ValueError: + if i == 4: + raise ValueError( + f"Could not unpack {case_id + suffix} after 5 tries!" + ) + continue + + +def verify_or_stratify_npys(path_to_data_dir: str | Path) -> None: + """ + This re-reads the npy files after unpacking. Should there be a loading issue with any, it will try to unpack this file again and overwrites the existing. + If the new file does not get saved correctly 5 times, it will raise an error with the file name to the user. Does the same for images and segmentations. + :param path_to_data_dir: Path to the preprocessed directory containing the npys and npzs. + :raises ValueError: If the npy file could not be unpacked after 5 tries. -- + Otherwise an obscured error will be raised later during training (depending when the broken file is sampled) + """ + path_to_data_dir = Path(path_to_data_dir) + # Check for broken image and segmentation npys + failed_data_ids, failed_seg_ids = find_broken_image_and_labels(path_to_data_dir) + + if len(failed_data_ids) != 0 or len(failed_seg_ids) != 0: + warn( + f"Found {len(failed_data_ids)} faulty data npys and {len(failed_seg_ids)}!\n" + + f"Faulty images: {failed_data_ids}; Faulty segmentations: {failed_seg_ids})\n" + + "Trying to fix them now." + ) + # Try to fix the broken npys by reextracting the npz. If that fails, raise error + try_fix_broken_npy(path_to_data_dir, failed_data_ids, fix_image=True) + try_fix_broken_npy(path_to_data_dir, failed_seg_ids, fix_image=False) + + def _convert_to_npy(npz_file: str, unpack_segmentation: bool = True, overwrite_existing: bool = False) -> None: try: a = np.load(npz_file) # inexpensive, no compression is done here. This just reads metadata diff --git a/nnunetv2/training/loss/compound_losses.py b/nnunetv2/training/loss/compound_losses.py index 9db0a42..eaeb5d8 100644 --- a/nnunetv2/training/loss/compound_losses.py +++ b/nnunetv2/training/loss/compound_losses.py @@ -38,11 +38,10 @@ def forward(self, net_output: torch.Tensor, target: torch.Tensor): if self.ignore_label is not None: assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \ '(DC_and_CE_loss)' - mask = (target != self.ignore_label).bool() + mask = target != self.ignore_label # remove ignore label from target, replace with one of the known labels. It doesn't matter because we # ignore gradients in those areas anyway - target_dice = torch.clone(target) - target_dice[target == self.ignore_label] = 0 + target_dice = torch.where(mask, target, 0) num_fg = mask.sum() else: target_dice = target @@ -50,7 +49,7 @@ def forward(self, net_output: torch.Tensor, target: torch.Tensor): dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \ if self.weight_dice != 0 else 0 - ce_loss = self.ce(net_output, target[:, 0].long()) \ + ce_loss = self.ce(net_output, target[:, 0]) \ if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0 result = self.weight_ce * ce_loss + self.weight_dice * dc_loss diff --git a/nnunetv2/training/loss/deep_supervision.py b/nnunetv2/training/loss/deep_supervision.py index 03141e8..952e3f7 100644 --- a/nnunetv2/training/loss/deep_supervision.py +++ b/nnunetv2/training/loss/deep_supervision.py @@ -1,3 +1,4 @@ +import torch from torch import nn @@ -11,25 +12,19 @@ def __init__(self, loss, weight_factors=None): If weights are None, all w will be 1. """ super(DeepSupervisionWrapper, self).__init__() - self.weight_factors = weight_factors + assert any([x != 0 for x in weight_factors]), "At least one weight factor should be != 0.0" + self.weight_factors = tuple(weight_factors) self.loss = loss def forward(self, *args): - for i in args: - assert isinstance(i, (tuple, list)), f"all args must be either tuple or list, got {type(i)}" - # we could check for equal lengths here as well but we really shouldn't overdo it with checks because - # this code is executed a lot of times! + assert all([isinstance(i, (tuple, list)) for i in args]), \ + f"all args must be either tuple or list, got {[type(i) for i in args]}" + # we could check for equal lengths here as well, but we really shouldn't overdo it with checks because + # this code is executed a lot of times! if self.weight_factors is None: - weights = [1] * len(args[0]) + weights = (1, ) * len(args[0]) else: weights = self.weight_factors - # we initialize the loss like this instead of 0 to ensure it sits on the correct device, not sure if that's - # really necessary - l = weights[0] * self.loss(*[j[0] for j in args]) - for i, inputs in enumerate(zip(*args)): - if i == 0: - continue - l += weights[i] * self.loss(*inputs) - return l \ No newline at end of file + return sum([weights[i] * self.loss(*inputs) for i, inputs in enumerate(zip(*args)) if weights[i] != 0.0]) diff --git a/nnunetv2/training/loss/dice.py b/nnunetv2/training/loss/dice.py index af55490..5744357 100644 --- a/nnunetv2/training/loss/dice.py +++ b/nnunetv2/training/loss/dice.py @@ -74,18 +74,18 @@ def forward(self, x, y, loss_mask=None): x = self.apply_nonlin(x) # make everything shape (b, c) - axes = list(range(2, len(x.shape))) + axes = tuple(range(2, x.ndim)) + with torch.no_grad(): - if len(x.shape) != len(y.shape): + if x.ndim != y.ndim: y = y.view((y.shape[0], 1, *y.shape[1:])) if x.shape == y.shape: # if this is the case then gt is probably already a one hot encoding y_onehot = y else: - gt = y.long() y_onehot = torch.zeros(x.shape, device=x.device, dtype=torch.bool) - y_onehot.scatter_(1, gt, 1) + y_onehot.scatter_(1, y.long(), 1) if not self.do_bg: y_onehot = y_onehot[:, 1:] @@ -96,15 +96,19 @@ def forward(self, x, y, loss_mask=None): if not self.do_bg: x = x[:, 1:] - intersect = (x * y_onehot).sum(axes) if loss_mask is None else (x * y_onehot * loss_mask).sum(axes) - sum_pred = x.sum(axes) if loss_mask is None else (x * loss_mask).sum(axes) - - if self.ddp and self.batch_dice: - intersect = AllGatherGrad.apply(intersect).sum(0) - sum_pred = AllGatherGrad.apply(sum_pred).sum(0) - sum_gt = AllGatherGrad.apply(sum_gt).sum(0) + if loss_mask is None: + intersect = (x * y_onehot).sum(axes) + sum_pred = x.sum(axes) + else: + intersect = (x * y_onehot * loss_mask).sum(axes) + sum_pred = (x * loss_mask).sum(axes) if self.batch_dice: + if self.ddp: + intersect = AllGatherGrad.apply(intersect).sum(0) + sum_pred = AllGatherGrad.apply(sum_pred).sum(0) + sum_gt = AllGatherGrad.apply(sum_gt).sum(0) + intersect = intersect.sum(0) sum_pred = sum_pred.sum(0) sum_gt = sum_gt.sum(0) @@ -128,22 +132,18 @@ def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False): :return: """ if axes is None: - axes = tuple(range(2, len(net_output.size()))) - - shp_x = net_output.shape - shp_y = gt.shape + axes = tuple(range(2, net_output.ndim)) with torch.no_grad(): - if len(shp_x) != len(shp_y): - gt = gt.view((shp_y[0], 1, *shp_y[1:])) + if net_output.ndim != gt.ndim: + gt = gt.view((gt.shape[0], 1, *gt.shape[1:])) if net_output.shape == gt.shape: # if this is the case then gt is probably already a one hot encoding y_onehot = gt else: - gt = gt.long() - y_onehot = torch.zeros(shp_x, device=net_output.device) - y_onehot.scatter_(1, gt, 1) + y_onehot = torch.zeros(net_output.shape, device=net_output.device) + y_onehot.scatter_(1, gt.long(), 1) tp = net_output * y_onehot fp = net_output * (1 - y_onehot) @@ -152,7 +152,7 @@ def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False): if mask is not None: with torch.no_grad(): - mask_here = torch.tile(mask, (1, tp.shape[1], *[1 for i in range(2, len(tp.shape))])) + mask_here = torch.tile(mask, (1, tp.shape[1], *[1 for _ in range(2, tp.ndim)])) tp *= mask_here fp *= mask_here fn *= mask_here diff --git a/nnunetv2/training/loss/robust_ce_loss.py b/nnunetv2/training/loss/robust_ce_loss.py index ad46659..3399e3a 100644 --- a/nnunetv2/training/loss/robust_ce_loss.py +++ b/nnunetv2/training/loss/robust_ce_loss.py @@ -10,7 +10,7 @@ class RobustCrossEntropyLoss(nn.CrossEntropyLoss): input must be logits, not probabilities! """ def forward(self, input: Tensor, target: Tensor) -> Tensor: - if len(target.shape) == len(input.shape): + if target.ndim == input.ndim: assert target.shape[1] == 1 target = target[:, 0] return super().forward(input, target.long()) @@ -30,4 +30,3 @@ def forward(self, inp, target): num_voxels = np.prod(res.shape, dtype=np.int64) res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False) return res.mean() - diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py index 8d9101f..6c368a5 100644 --- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py +++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -11,6 +11,8 @@ import numpy as np import torch +from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter +from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \ @@ -52,13 +54,13 @@ from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler from nnunetv2.utilities.collate_outputs import collate_outputs +from nnunetv2.utilities.crossval_split import generate_crossval_split from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy from nnunetv2.utilities.get_network_from_plans import get_network_from_plans from nnunetv2.utilities.helpers import empty_cache, dummy_context from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager -from sklearn.model_selection import KFold from torch import autocast, nn from torch import distributed as dist from torch.cuda import device_count @@ -148,6 +150,7 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dic self.num_val_iterations_per_epoch = 50 self.num_epochs = 1000 self.current_epoch = 0 + self.enable_deep_supervision = True ### Dealing with labels/regions self.label_manager = self.plans_manager.get_label_manager(dataset_json) @@ -155,7 +158,7 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dic # needed for predictions. We do sigmoid in case of (overlapping) regions self.num_input_channels = None # -> self.initialize() - self.network = None # -> self._get_network() + self.network = None # -> self.build_network_architecture() self.optimizer = self.lr_scheduler = None # -> self.initialize self.grad_scaler = GradScaler() if self.device.type == 'cuda' else None self.loss = None # -> self.initialize @@ -203,13 +206,17 @@ def initialize(self): self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager, self.dataset_json) - self.network = self.build_network_architecture(self.plans_manager, self.dataset_json, - self.configuration_manager, - self.num_input_channels, - enable_deep_supervision=True).to(self.device) + self.network = self.build_network_architecture( + self.configuration_manager.network_arch_class_name, + self.configuration_manager.network_arch_init_kwargs, + self.configuration_manager.network_arch_init_kwargs_req_import, + self.num_input_channels, + self.label_manager.num_segmentation_heads, + self.enable_deep_supervision + ).to(self.device) # compile network for free speedup if self._do_i_compile(): - self.print_to_log_file('Compiling network...') + self.print_to_log_file('Using torch.compile...') self.network = torch.compile(self.network) self.optimizer, self.lr_scheduler = self.configure_optimizers() @@ -263,13 +270,14 @@ def _save_debug_information(self): save_json(dct, join(self.output_folder, "debug.json")) @staticmethod - def build_network_architecture(plans_manager: PlansManager, - dataset_json, - configuration_manager: ConfigurationManager, - num_input_channels, + def build_network_architecture(architecture_class_name: str, + arch_init_kwargs: dict, + arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]], + num_input_channels: int, + num_output_channels: int, enable_deep_supervision: bool = True) -> nn.Module: """ - his is where you build the architecture according to the plans. There is no obligation to use + This is where you build the architecture according to the plans. There is no obligation to use get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what you want. Even ignore the plans and just return something static (as long as it can process the requested patch size) @@ -287,12 +295,21 @@ def build_network_architecture(plans_manager: PlansManager, should be generated. label_manager takes care of all that for you.) """ - return get_network_from_plans(plans_manager, dataset_json, configuration_manager, - num_input_channels, deep_supervision=enable_deep_supervision) + return get_network_from_plans( + architecture_class_name, + arch_init_kwargs, + arch_init_kwargs_req_import, + num_input_channels, + num_output_channels, + allow_init=True, + deep_supervision=enable_deep_supervision) def _get_deep_supervision_scales(self): - deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack( - self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1] + if self.enable_deep_supervision: + deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack( + self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1] + else: + deep_supervision_scales = None # for train and val_transforms return deep_supervision_scales def _set_batch_size_and_oversample(self): @@ -301,8 +318,6 @@ def _set_batch_size_and_oversample(self): self.batch_size = self.configuration_manager.batch_size else: # batch size is distributed over DDP workers and we need to change oversample_percent for each worker - batch_sizes = [] - oversample_percents = [] world_size = dist.get_world_size() my_rank = dist.get_rank() @@ -311,36 +326,38 @@ def _set_batch_size_and_oversample(self): assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \ 'GPUs... Duh.' - batch_size_per_GPU = np.ceil(global_batch_size / world_size).astype(int) - - for rank in range(world_size): - if (rank + 1) * batch_size_per_GPU > global_batch_size: - batch_size = batch_size_per_GPU - ((rank + 1) * batch_size_per_GPU - global_batch_size) - else: - batch_size = batch_size_per_GPU - - batch_sizes.append(batch_size) - - sample_id_low = 0 if len(batch_sizes) == 0 else np.sum(batch_sizes[:-1]) - sample_id_high = np.sum(batch_sizes) - - if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent): - oversample_percents.append(0.0) - elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent): - oversample_percents.append(1.0) - else: - percent_covered_by_this_rank = sample_id_high / global_batch_size - sample_id_low / global_batch_size - oversample_percent_here = 1 - (((1 - self.oversample_foreground_percent) - - sample_id_low / global_batch_size) / percent_covered_by_this_rank) - oversample_percents.append(oversample_percent_here) + batch_size_per_GPU = [global_batch_size // world_size] * world_size + batch_size_per_GPU = [batch_size_per_GPU[i] + 1 + if (batch_size_per_GPU[i] * world_size + i) < global_batch_size + else batch_size_per_GPU[i] + for i in range(len(batch_size_per_GPU))] + assert sum(batch_size_per_GPU) == global_batch_size + + sample_id_low = 0 if my_rank == 0 else np.sum(batch_size_per_GPU[:my_rank]) + sample_id_high = np.sum(batch_size_per_GPU[:my_rank + 1]) + + # This is how oversampling is determined in DataLoader + # round(self.batch_size * (1 - self.oversample_foreground_percent)) + # We need to use the same scheme here because an oversample of 0.33 with a batch size of 2 will be rounded + # to an oversample of 0.5 (1 sample random, one oversampled). This may get lost if we just numerically + # compute oversample + oversample = [True if not i < round(global_batch_size * (1 - self.oversample_foreground_percent)) else False + for i in range(global_batch_size)] + + if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent): + oversample_percent = 0.0 + elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent): + oversample_percent = 1.0 + else: + oversample_percent = sum(oversample[sample_id_low:sample_id_high]) / batch_size_per_GPU[my_rank] - print("worker", my_rank, "oversample", oversample_percents[my_rank]) - print("worker", my_rank, "batch_size", batch_sizes[my_rank]) + print("worker", my_rank, "oversample", oversample_percent) + print("worker", my_rank, "batch_size", batch_size_per_GPU[my_rank]) # self.print_to_log_file("worker", my_rank, "oversample", oversample_percents[my_rank]) # self.print_to_log_file("worker", my_rank, "batch_size", batch_sizes[my_rank]) - self.batch_size = batch_sizes[my_rank] - self.oversample_foreground_percent = oversample_percents[my_rank] + self.batch_size = batch_size_per_GPU[my_rank] + self.oversample_foreground_percent = oversample_percent def _build_loss(self): if self.label_manager.has_regions: @@ -354,17 +371,24 @@ def _build_loss(self): 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss) - deep_supervision_scales = self._get_deep_supervision_scales() - # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss - weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) - weights[-1] = 0 - # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 - weights = weights / weights.sum() - # now wrap the loss - loss = DeepSupervisionWrapper(loss, weights) + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + if self.is_ddp and not self._do_i_compile(): + # very strange and stupid interaction. DDP crashes and complains about unused parameters due to + # weights[-1] = 0. Interestingly this crash doesn't happen with torch.compile enabled. Strange stuff. + # Anywho, the simple fix is to set a very low weight to this. + weights[-1] = 1e-6 + else: + weights[-1] = 0 + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) return loss def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): @@ -508,9 +532,9 @@ def plot_network_architecture(self): def do_split(self): """ The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded, - so always the same) and save it as splits_final.pkl file in the preprocessed data directory. + so always the same) and save it as splits_final.json file in the preprocessed data directory. Sometimes you may want to create your own split for various reasons. For this you will need to create your own - splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in + splits_final.json file. If this file is present, nnU-Net is going to use it and whatever splits are defined in it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3) and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to use a random 80:20 data split. @@ -529,15 +553,8 @@ def do_split(self): # if the split file does not exist we need to create it if not isfile(splits_file): self.print_to_log_file("Creating new 5-fold cross-validation split...") - splits = [] - all_keys_sorted = np.sort(list(dataset.keys())) - kfold = KFold(n_splits=5, shuffle=True, random_state=12345) - for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)): - train_keys = np.array(all_keys_sorted)[train_idx] - test_keys = np.array(all_keys_sorted)[test_idx] - splits.append({}) - splits[-1]['train'] = list(train_keys) - splits[-1]['val'] = list(test_keys) + all_keys_sorted = list(np.sort(list(dataset.keys()))) + splits = generate_crossval_split(all_keys_sorted, seed=12345, n_splits=5) save_json(splits, splits_file) else: @@ -591,10 +608,15 @@ def get_dataloaders(self): # needed for deep supervision: how much do we need to downscale the segmentation targets for the different # outputs? + deep_supervision_scales = self._get_deep_supervision_scales() - rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ - self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + ( + rotation_for_DA, + do_dummy_2d_data_aug, + initial_patch_size, + mirror_axes, + ) = self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() # training pipeline tr_transforms = self.get_training_transforms( @@ -663,19 +685,21 @@ def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int): return dl_tr, dl_val @staticmethod - def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], - rotation_for_DA: dict, - deep_supervision_scales: Union[List, Tuple], - mirror_axes: Tuple[int, ...], - do_dummy_2d_data_aug: bool, - order_resampling_data: int = 3, - order_resampling_seg: int = 1, - border_val_seg: int = -1, - use_mask_for_norm: List[bool] = None, - is_cascaded: bool = False, - foreground_labels: Union[Tuple[int, ...], List[int]] = None, - regions: List[Union[List[int], Tuple[int, ...], int]] = None, - ignore_label: int = None) -> AbstractTransform: + def get_training_transforms( + patch_size: Union[np.ndarray, Tuple[int]], + rotation_for_DA: dict, + deep_supervision_scales: Union[List, Tuple, None], + mirror_axes: Tuple[int, ...], + do_dummy_2d_data_aug: bool, + order_resampling_data: int = 3, + order_resampling_seg: int = 1, + border_val_seg: int = -1, + use_mask_for_norm: List[bool] = None, + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None, + ) -> AbstractTransform: tr_transforms = [] if do_dummy_2d_data_aug: ignore_axes = (0,) @@ -755,11 +779,13 @@ def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], return tr_transforms @staticmethod - def get_validation_transforms(deep_supervision_scales: Union[List, Tuple], - is_cascaded: bool = False, - foreground_labels: Union[Tuple[int, ...], List[int]] = None, - regions: List[Union[List[int], Tuple[int, ...], int]] = None, - ignore_label: int = None) -> AbstractTransform: + def get_validation_transforms( + deep_supervision_scales: Union[List, Tuple, None], + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None, + ) -> AbstractTransform: val_transforms = [] val_transforms.append(RemoveLabelTransform(-1, 0)) @@ -788,9 +814,13 @@ def set_deep_supervision_enabled(self, enabled: bool): chances you need to change this as well! """ if self.is_ddp: - self.network.module.decoder.deep_supervision = enabled + mod = self.network.module else: - self.network.decoder.deep_supervision = enabled + mod = self.network + if isinstance(mod, OptimizedModule): + mod = mod._orig_mod + + mod.decoder.deep_supervision = enabled def on_train_start(self): if not self.was_initialized: @@ -799,7 +829,7 @@ def on_train_start(self): maybe_mkdir_p(self.output_folder) # make sure deep supervision is on in the network - self.set_deep_supervision_enabled(True) + self.set_deep_supervision_enabled(self.enable_deep_supervision) self.print_plans() empty_cache(self.device) @@ -849,9 +879,11 @@ def on_train_end(self): old_stdout = sys.stdout with open(os.devnull, 'w') as f: sys.stdout = f - if self.dataloader_train is not None: + if self.dataloader_train is not None and \ + isinstance(self.dataloader_train, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)): self.dataloader_train._finish() - if self.dataloader_val is not None: + if self.dataloader_val is not None and \ + isinstance(self.dataloader_train, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)): self.dataloader_val._finish() sys.stdout = old_stdout @@ -879,7 +911,7 @@ def train_step(self, batch: dict) -> dict: target = target.to(self.device, non_blocking=True) self.optimizer.zero_grad(set_to_none=True) - # Autocast is a little bitch. + # Autocast can be annoying # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) # So autocast will only be active if we have a cuda device. @@ -925,7 +957,7 @@ def validation_step(self, batch: dict) -> dict: else: target = target.to(self.device, non_blocking=True) - # Autocast is a little bitch. + # Autocast can be annoying # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) # So autocast will only be active if we have a cuda device. @@ -934,9 +966,10 @@ def validation_step(self, batch: dict) -> dict: del data l = self.loss(output, target) - # we only need the output with the highest output resolution - output = output[0] - target = target[0] + # we only need the output with the highest output resolution (if DS enabled) + if self.enable_deep_supervision: + output = output[0] + target = target[0] # the following is needed for online evaluation. Fake dice (green line) axes = [0] + list(range(2, output.ndim)) @@ -1005,8 +1038,7 @@ def on_validation_epoch_end(self, val_outputs: List[dict]): else: loss_here = np.mean(outputs_collated['loss']) - global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in - zip(tp, fp, fn)]] + global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in zip(tp, fp, fn)]] mean_fg_dice = np.nanmean(global_dc_per_class) self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch) self.logger.log('dice_per_class_or_region', global_dc_per_class, self.current_epoch) @@ -1018,7 +1050,6 @@ def on_epoch_start(self): def on_epoch_end(self): self.logger.log('epoch_end_timestamps', time(), self.current_epoch) - # todo find a solution for this stupid shit self.print_to_log_file('train_loss', np.round(self.logger.my_fantastic_logging['train_losses'][-1], decimals=4)) self.print_to_log_file('val_loss', np.round(self.logger.my_fantastic_logging['val_losses'][-1], decimals=4)) self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in @@ -1109,8 +1140,18 @@ def perform_actual_validation(self, save_probabilities: bool = False): self.set_deep_supervision_enabled(False) self.network.eval() + if self.is_ddp and self.batch_size == 1 and self.enable_deep_supervision and self._do_i_compile(): + self.print_to_log_file("WARNING! batch size is 1 during training and torch.compile is enabled. If you " + "encounter crashes in validation then this is because torch.compile forgets " + "to trigger a recompilation of the model with deep supervision disabled. " + "This causes torch.flip to complain about getting a tuple as input. Just rerun the " + "validation with --val (exactly the same as before) and then it will work. " + "Why? Because --val triggers nnU-Net to ONLY run validation meaning that the first " + "forward pass (where compile is triggered) already has deep supervision disabled. " + "This is exactly what we need in perform_actual_validation") + predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True, - perform_everything_on_gpu=True, device=self.device, verbose=False, + perform_everything_on_device=True, device=self.device, verbose=False, verbose_preprocessing=False, allow_tqdm=False) predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None, self.dataset_json, self.__class__.__name__, @@ -1125,7 +1166,11 @@ def perform_actual_validation(self, save_probabilities: bool = False): # the validation keys across the workers. _, val_keys = self.do_split() if self.is_ddp: + last_barrier_at_idx = len(val_keys) // dist.get_world_size() - 1 + val_keys = val_keys[self.local_rank:: dist.get_world_size()] + # we cannot just have barriers all over the place because the number of keys each GPU receives can be + # different dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys, folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, @@ -1138,13 +1183,13 @@ def perform_actual_validation(self, save_probabilities: bool = False): results = [] - for k in dataset_val.keys(): + for i, k in enumerate(dataset_val.keys()): proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, - allowed_num_queued=2) + allowed_num_queued=2) while not proceed: sleep(0.1) proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, - allowed_num_queued=2) + allowed_num_queued=2) self.print_to_log_file(f"predicting {k}") data, seg, properties = dataset_val.load_case(k) @@ -1157,15 +1202,10 @@ def perform_actual_validation(self, save_probabilities: bool = False): warnings.simplefilter("ignore") data = torch.from_numpy(data) + self.print_to_log_file(f'{k}, shape {data.shape}, rank {self.local_rank}') output_filename_truncated = join(validation_output_folder, k) - try: - prediction = predictor.predict_sliding_window_return_logits(data) - except RuntimeError: - predictor.perform_everything_on_gpu = False - prediction = predictor.predict_sliding_window_return_logits(data) - predictor.perform_everything_on_gpu = True - + prediction = predictor.predict_sliding_window_return_logits(data) prediction = prediction.cpu() # this needs to go into background processes @@ -1213,6 +1253,9 @@ def perform_actual_validation(self, save_probabilities: bool = False): self.dataset_json), ) )) + # if we don't barrier from time to time we will get nccl timeouts for large datasets. Yuck. + if self.is_ddp and i < last_barrier_at_idx and (i + 1) % 20 == 0: + dist.barrier() _ = [r.get() for r in results] @@ -1227,9 +1270,12 @@ def perform_actual_validation(self, save_probabilities: bool = False): self.dataset_json["file_ending"], self.label_manager.foreground_regions if self.label_manager.has_regions else self.label_manager.foreground_labels, - self.label_manager.ignore_label, chill=True) + self.label_manager.ignore_label, chill=True, + num_processes=default_num_processes * dist.get_world_size() if + self.is_ddp else default_num_processes) self.print_to_log_file("Validation complete", also_print_to_console=True) - self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), also_print_to_console=True) + self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), + also_print_to_console=True) self.set_deep_supervision_enabled(True) compute_gaussian.cache_clear() diff --git a/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py b/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py index 6c12ecc..e7de92c 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py +++ b/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py @@ -1,25 +1,39 @@ import torch -from nnunetv2.training.nnUNetTrainer.variants.benchmarking.nnUNetTrainerBenchmark_5epochs import \ - nnUNetTrainerBenchmark_5epochs +from nnunetv2.training.nnUNetTrainer.variants.benchmarking.nnUNetTrainerBenchmark_5epochs import ( + nnUNetTrainerBenchmark_5epochs, +) from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels class nnUNetTrainerBenchmark_5epochs_noDataLoading(nnUNetTrainerBenchmark_5epochs): - def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): + def __init__( + self, + plans: dict, + configuration: str, + fold: int, + dataset_json: dict, + unpack_dataset: bool = True, + device: torch.device = torch.device("cuda"), + ): super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) self._set_batch_size_and_oversample() - num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager, - self.dataset_json) + num_input_channels = determine_num_input_channels( + self.plans_manager, self.configuration_manager, self.dataset_json + ) patch_size = self.configuration_manager.patch_size dummy_data = torch.rand((self.batch_size, num_input_channels, *patch_size), device=self.device) - dummy_target = [ - torch.round( - torch.rand((self.batch_size, 1, *[int(i * j) for i, j in zip(patch_size, k)]), device=self.device) * - max(self.label_manager.all_labels) - ) for k in self._get_deep_supervision_scales()] - self.dummy_batch = {'data': dummy_data, 'target': dummy_target} + if self.enable_deep_supervision: + dummy_target = [ + torch.round( + torch.rand((self.batch_size, 1, *[int(i * j) for i, j in zip(patch_size, k)]), device=self.device) + * max(self.label_manager.all_labels) + ) + for k in self._get_deep_supervision_scales() + ] + else: + raise NotImplementedError("This trainer does not support deep supervision") + self.dummy_batch = {"data": dummy_data, "target": dummy_target} def get_dataloaders(self): return None, None diff --git a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py index bd9c31c..a96cb2b 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py +++ b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py @@ -34,9 +34,6 @@ class nnUNetTrainerDA5(nnUNetTrainer): def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): - """ - This function is stupid and certainly one of the weakest spots of this implementation. Not entirely sure how we can fix it. - """ patch_size = self.configuration_manager.patch_size dim = len(patch_size) # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation) @@ -93,7 +90,7 @@ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): @staticmethod def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], rotation_for_DA: dict, - deep_supervision_scales: Union[List, Tuple], + deep_supervision_scales: Union[List, Tuple, None], mirror_axes: Tuple[int, ...], do_dummy_2d_data_aug: bool, order_resampling_data: int = 3, @@ -233,9 +230,9 @@ def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], tr_transforms.append( BrightnessGradientAdditiveTransform( - lambda x, y: np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y]))), + _brightnessadditive_localgamma_transform_scale, (-0.5, 1.5), - max_strength=lambda x, y: np.random.uniform(-5, -1) if np.random.uniform() < 0.5 else np.random.uniform(1, 5), + max_strength=_brightness_gradient_additive_max_strength, mean_centered=False, same_for_all_channels=False, p_per_sample=0.3, @@ -245,9 +242,9 @@ def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], tr_transforms.append( LocalGammaTransform( - lambda x, y: np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y]))), + _brightnessadditive_localgamma_transform_scale, (-0.5, 1.5), - lambda: np.random.uniform(0.01, 0.8) if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4), + _local_gamma_gamma, same_for_all_channels=False, p_per_sample=0.3, p_per_channel=0.5 @@ -354,6 +351,18 @@ def get_dataloaders(self): return mt_gen_train, mt_gen_val +def _brightnessadditive_localgamma_transform_scale(x, y): + return np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y]))) + + +def _brightness_gradient_additive_max_strength(_x, _y): + return np.random.uniform(-5, -1) if np.random.uniform() < 0.5 else np.random.uniform(1, 5) + + +def _local_gamma_gamma(): + return np.random.uniform(0.01, 0.8) if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4) + + class nnUNetTrainerDA5Segord0(nnUNetTrainerDA5): def get_dataloaders(self): """ diff --git a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py index e87ff8f..be31857 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py +++ b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py @@ -102,3 +102,56 @@ def get_dataloaders(self): max(1, allowed_num_processes // 2), 3, None, True, 0.02) return mt_gen_train, mt_gen_val + + +class nnUNetTrainer_DASegOrd0_NoMirroring(nnUNetTrainer): + def get_dataloaders(self): + """ + changed order_resampling_data, order_resampling_seg + """ + # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether + # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + + # needed for deep supervision: how much do we need to downscale the segmentation targets for the different + # outputs? + deep_supervision_scales = self._get_deep_supervision_scales() + + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + + # Deactivate mirroring data augmentation + mirror_axes = None + self.inference_allowed_mirroring_axes = None + + # training pipeline + tr_transforms = self.get_training_transforms( + patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, + order_resampling_data=3, order_resampling_seg=0, + use_mask_for_norm=self.configuration_manager.use_mask_for_norm, + is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + # validation pipeline + val_transforms = self.get_validation_transforms(deep_supervision_scales, + is_cascaded=self.is_cascaded, + foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if + self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) + + allowed_num_processes = get_allowed_n_proc_DA() + if allowed_num_processes == 0: + mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: + mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms, + allowed_num_processes, 6, None, True, 0.02) + mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms, + max(1, allowed_num_processes // 2), 3, None, True, 0.02) + + return mt_gen_train, mt_gen_val diff --git a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py index 527e262..17f3586 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py +++ b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py @@ -10,7 +10,7 @@ class nnUNetTrainerNoDA(nnUNetTrainer): @staticmethod def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], rotation_for_DA: dict, - deep_supervision_scales: Union[List, Tuple], + deep_supervision_scales: Union[List, Tuple, None], mirror_axes: Tuple[int, ...], do_dummy_2d_data_aug: bool, order_resampling_data: int = 1, diff --git a/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py b/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py index c8432df..fdc0fea 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py +++ b/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py @@ -7,27 +7,35 @@ class nnUNetTrainerCELoss(nnUNetTrainer): def _build_loss(self): - assert not self.label_manager.has_regions, 'regions not supported by this trainer' - loss = RobustCrossEntropyLoss(weight=None, - ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100) - - deep_supervision_scales = self._get_deep_supervision_scales() + assert not self.label_manager.has_regions, "regions not supported by this trainer" + loss = RobustCrossEntropyLoss( + weight=None, ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100 + ) # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss - weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) - weights[-1] = 0 + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() + weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 - # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 - weights = weights / weights.sum() - # now wrap the loss - loss = DeepSupervisionWrapper(loss, weights) + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) return loss class nnUNetTrainerCELoss_5epochs(nnUNetTrainerCELoss): - def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): + def __init__( + self, + plans: dict, + configuration: str, + fold: int, + dataset_json: dict, + unpack_dataset: bool = True, + device: torch.device = torch.device("cuda"), + ): """used for debugging plans etc""" super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) self.num_epochs = 5 diff --git a/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py b/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py index 6f0b7c0..b139286 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py +++ b/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py @@ -14,17 +14,18 @@ def _build_loss(self): 'do_bg': self.label_manager.has_regions, 'smooth': 1e-5, 'ddp': self.is_ddp}, apply_nonlin=torch.sigmoid if self.label_manager.has_regions else softmax_helper_dim1) - deep_supervision_scales = self._get_deep_supervision_scales() - - # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases - # this gives higher resolution outputs more weight in the loss - weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) - weights[-1] = 0 - - # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 - weights = weights / weights.sum() - # now wrap the loss - loss = DeepSupervisionWrapper(loss, weights) + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) return loss @@ -43,16 +44,17 @@ def _build_loss(self): ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss) - deep_supervision_scales = self._get_deep_supervision_scales() + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() - # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases - # this gives higher resolution outputs more weight in the loss - weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) - weights[-1] = 0 + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 - # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 - weights = weights / weights.sum() - # now wrap the loss - loss = DeepSupervisionWrapper(loss, weights) + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) return loss diff --git a/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py b/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py index afb3fe1..5eff10e 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py +++ b/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py @@ -7,63 +7,70 @@ class nnUNetTrainerTopk10Loss(nnUNetTrainer): def _build_loss(self): - assert not self.label_manager.has_regions, 'regions not supported by this trainer' - loss = TopKLoss(ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, - k=10) + assert not self.label_manager.has_regions, "regions not supported by this trainer" + loss = TopKLoss( + ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, k=10 + ) - deep_supervision_scales = self._get_deep_supervision_scales() + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() - # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases - # this gives higher resolution outputs more weight in the loss - weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) - weights[-1] = 0 + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 - # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 - weights = weights / weights.sum() - # now wrap the loss - loss = DeepSupervisionWrapper(loss, weights) + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) return loss class nnUNetTrainerTopk10LossLS01(nnUNetTrainer): def _build_loss(self): - assert not self.label_manager.has_regions, 'regions not supported by this trainer' - loss = TopKLoss(ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, - k=10, label_smoothing=0.1) + assert not self.label_manager.has_regions, "regions not supported by this trainer" + loss = TopKLoss( + ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, + k=10, + label_smoothing=0.1, + ) - deep_supervision_scales = self._get_deep_supervision_scales() + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() - # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases - # this gives higher resolution outputs more weight in the loss - weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) - weights[-1] = 0 + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 - # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 - weights = weights / weights.sum() - # now wrap the loss - loss = DeepSupervisionWrapper(loss, weights) + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) return loss class nnUNetTrainerDiceTopK10Loss(nnUNetTrainer): def _build_loss(self): - assert not self.label_manager.has_regions, 'regions not supported by this trainer' - loss = DC_and_topk_loss({'batch_dice': self.configuration_manager.batch_dice, - 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, - {'k': 10, - 'label_smoothing': 0.0}, - weight_ce=1, weight_dice=1, - ignore_label=self.label_manager.ignore_label) + assert not self.label_manager.has_regions, "regions not supported by this trainer" + loss = DC_and_topk_loss( + {"batch_dice": self.configuration_manager.batch_dice, "smooth": 1e-5, "do_bg": False, "ddp": self.is_ddp}, + {"k": 10, "label_smoothing": 0.0}, + weight_ce=1, + weight_dice=1, + ignore_label=self.label_manager.ignore_label, + ) + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() - deep_supervision_scales = self._get_deep_supervision_scales() + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 - # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases - # this gives higher resolution outputs more weight in the loss - weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) - weights[-1] = 0 - - # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 - weights = weights / weights.sum() - # now wrap the loss - loss = DeepSupervisionWrapper(loss, weights) + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) return loss diff --git a/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py b/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py index 5f6190c..50d0c9f 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py +++ b/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py @@ -1,73 +1,32 @@ -from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet, PlainConvUNet -from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_batchnorm -from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0, InitWeights_He -from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer -from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from typing import Union, Tuple, List +from dynamic_network_architectures.building_blocks.helper import get_matching_batchnorm from torch import nn +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer + class nnUNetTrainerBN(nnUNetTrainer): @staticmethod - def build_network_architecture(plans_manager: PlansManager, - dataset_json, - configuration_manager: ConfigurationManager, - num_input_channels, + def build_network_architecture(architecture_class_name: str, + arch_init_kwargs: dict, + arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]], + num_input_channels: int, + num_output_channels: int, enable_deep_supervision: bool = True) -> nn.Module: - num_stages = len(configuration_manager.conv_kernel_sizes) - dim = len(configuration_manager.conv_kernel_sizes[0]) - conv_op = convert_dim_to_conv_op(dim) + if 'norm_op' not in arch_init_kwargs.keys(): + raise RuntimeError("'norm_op' not found in arch_init_kwargs. This does not look like an architecture " + "I can hack BN into. This trainer only works with default nnU-Net architectures.") - label_manager = plans_manager.get_label_manager(dataset_json) + from pydoc import locate + conv_op = locate(arch_init_kwargs['conv_op']) + bn_class = get_matching_batchnorm(conv_op) + arch_init_kwargs['norm_op'] = bn_class.__module__ + '.' + bn_class.__name__ + arch_init_kwargs['norm_op_kwargs'] = {'eps': 1e-5, 'affine': True} - segmentation_network_class_name = configuration_manager.UNet_class_name - mapping = { - 'PlainConvUNet': PlainConvUNet, - 'ResidualEncoderUNet': ResidualEncoderUNet - } - kwargs = { - 'PlainConvUNet': { - 'conv_bias': True, - 'norm_op': get_matching_batchnorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - }, - 'ResidualEncoderUNet': { - 'conv_bias': True, - 'norm_op': get_matching_batchnorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - } - } - assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ - 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ - 'into either this ' \ - 'function (get_network_from_plans) or ' \ - 'the init of your nnUNetModule to accommodate that.' - network_class = mapping[segmentation_network_class_name] + return nnUNetTrainer.build_network_architecture(architecture_class_name, + arch_init_kwargs, + arch_init_kwargs_req_import, + num_input_channels, + num_output_channels, enable_deep_supervision) - conv_or_blocks_per_stage = { - 'n_conv_per_stage' - if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': configuration_manager.n_conv_per_stage_encoder, - 'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder - } - # network class name!! - model = network_class( - input_channels=num_input_channels, - n_stages=num_stages, - features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i, - configuration_manager.unet_max_num_features) for i in range(num_stages)], - conv_op=conv_op, - kernel_sizes=configuration_manager.conv_kernel_sizes, - strides=configuration_manager.pool_op_kernel_sizes, - num_classes=label_manager.num_segmentation_heads, - deep_supervision=enable_deep_supervision, - **conv_or_blocks_per_stage, - **kwargs[segmentation_network_class_name] - ) - model.apply(InitWeights_He(1e-2)) - if network_class == ResidualEncoderUNet: - model.apply(init_last_bn_before_add_to_0) - return model diff --git a/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py b/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py index 34f9b55..1152fbe 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py +++ b/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py @@ -1,114 +1,16 @@ -import torch -from torch import autocast - -from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss -from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer -from nnunetv2.utilities.helpers import dummy_context -from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels -from torch.nn.parallel import DistributedDataParallel as DDP +import torch class nnUNetTrainerNoDeepSupervision(nnUNetTrainer): - def _build_loss(self): - if self.label_manager.has_regions: - loss = DC_and_BCE_loss({}, - {'batch_dice': self.configuration_manager.batch_dice, - 'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp}, - use_ignore_label=self.label_manager.ignore_label is not None, - dice_class=MemoryEfficientSoftDiceLoss) - else: - loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice, - 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, - ignore_label=self.label_manager.ignore_label, - dice_class=MemoryEfficientSoftDiceLoss) - return loss - - def _get_deep_supervision_scales(self): - return None - - def initialize(self): - if not self.was_initialized: - self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager, - self.dataset_json) - - self.network = self.build_network_architecture(self.plans_manager, self.dataset_json, - self.configuration_manager, - self.num_input_channels, - enable_deep_supervision=False).to(self.device) - - self.optimizer, self.lr_scheduler = self.configure_optimizers() - # if ddp, wrap in DDP wrapper - if self.is_ddp: - self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network) - self.network = DDP(self.network, device_ids=[self.local_rank]) - - self.loss = self._build_loss() - self.was_initialized = True - else: - raise RuntimeError("You have called self.initialize even though the trainer was already initialized. " - "That should not happen.") - - def set_deep_supervision_enabled(self, enabled: bool): - pass - - def validation_step(self, batch: dict) -> dict: - data = batch['data'] - target = batch['target'] - - data = data.to(self.device, non_blocking=True) - if isinstance(target, list): - target = [i.to(self.device, non_blocking=True) for i in target] - else: - target = target.to(self.device, non_blocking=True) - - self.optimizer.zero_grad(set_to_none=True) - - # Autocast is a little bitch. - # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. - # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) - # So autocast will only be active if we have a cuda device. - with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): - output = self.network(data) - del data - l = self.loss(output, target) - - # the following is needed for online evaluation. Fake dice (green line) - axes = [0] + list(range(2, output.ndim)) - - if self.label_manager.has_regions: - predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long() - else: - # no need for softmax - output_seg = output.argmax(1)[:, None] - predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32) - predicted_segmentation_onehot.scatter_(1, output_seg, 1) - del output_seg - - if self.label_manager.has_ignore_label: - if not self.label_manager.has_regions: - mask = (target != self.label_manager.ignore_label).float() - # CAREFUL that you don't rely on target after this line! - target[target == self.label_manager.ignore_label] = 0 - else: - mask = 1 - target[:, -1:] - # CAREFUL that you don't rely on target after this line! - target = target[:, :-1] - else: - mask = None - - tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) - - tp_hard = tp.detach().cpu().numpy() - fp_hard = fp.detach().cpu().numpy() - fn_hard = fn.detach().cpu().numpy() - if not self.label_manager.has_regions: - # if we train with regions all segmentation heads predict some kind of foreground. In conventional - # (softmax training) there needs tobe one output for the background. We are not interested in the - # background Dice - # [1:] in order to remove background - tp_hard = tp_hard[1:] - fp_hard = fp_hard[1:] - fn_hard = fn_hard[1:] - - return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} \ No newline at end of file + def __init__( + self, + plans: dict, + configuration: str, + fold: int, + dataset_json: dict, + unpack_dataset: bool = True, + device: torch.device = torch.device("cuda"), + ): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.enable_deep_supervision = False diff --git a/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py b/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py index 89fef48..467a6fd 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py +++ b/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import Tuple import torch @@ -59,6 +60,13 @@ def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int): sampling_probabilities=None, pad_sides=None, probabilistic_oversampling=True) return dl_tr, dl_val + def _set_batch_size_and_oversample(self): + old_oversample = deepcopy(self.oversample_foreground_percent) + super()._set_batch_size_and_oversample() + self.oversample_foreground_percent = old_oversample + self.print_to_log_file(f"Ignore previous message about oversample_foreground_percent. " + f"oversample_foreground_percent overwritten to {self.oversample_foreground_percent}") + class nnUNetTrainer_probabilisticOversampling_033(nnUNetTrainer_probabilisticOversampling): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, diff --git a/nnunetv2/utilities/crossval_split.py b/nnunetv2/utilities/crossval_split.py new file mode 100644 index 0000000..472603b --- /dev/null +++ b/nnunetv2/utilities/crossval_split.py @@ -0,0 +1,16 @@ +from typing import List + +import numpy as np +from sklearn.model_selection import KFold + + +def generate_crossval_split(train_identifiers: List[str], seed=12345, n_splits=5) -> List[dict[str, List[str]]]: + splits = [] + kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed) + for i, (train_idx, test_idx) in enumerate(kfold.split(train_identifiers)): + train_keys = np.array(train_identifiers)[train_idx] + test_keys = np.array(train_identifiers)[test_idx] + splits.append({}) + splits[-1]['train'] = list(train_keys) + splits[-1]['val'] = list(test_keys) + return splits diff --git a/nnunetv2/utilities/get_network_from_plans.py b/nnunetv2/utilities/get_network_from_plans.py index 1dd1dd2..be79777 100644 --- a/nnunetv2/utilities/get_network_from_plans.py +++ b/nnunetv2/utilities/get_network_from_plans.py @@ -1,77 +1,43 @@ -from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet -from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op -from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 -from nnunetv2.utilities.network_initialization import InitWeights_He -from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager -from torch import nn +import pydoc +import warnings +from typing import Union +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from batchgenerators.utilities.file_and_folder_operations import join -def get_network_from_plans(plans_manager: PlansManager, - dataset_json: dict, - configuration_manager: ConfigurationManager, - num_input_channels: int, - deep_supervision: bool = True): - """ - we may have to change this in the future to accommodate other plans -> network mappings - num_input_channels can differ depending on whether we do cascade. Its best to make this info available in the - trainer rather than inferring it again from the plans here. - """ - num_stages = len(configuration_manager.conv_kernel_sizes) +def get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels, + allow_init=True, deep_supervision: Union[bool, None] = None): + network_class = arch_class_name + architecture_kwargs = dict(**arch_kwargs) + for ri in arch_kwargs_req_import: + if architecture_kwargs[ri] is not None: + architecture_kwargs[ri] = pydoc.locate(architecture_kwargs[ri]) - dim = len(configuration_manager.conv_kernel_sizes[0]) - conv_op = convert_dim_to_conv_op(dim) + nw_class = pydoc.locate(network_class) + # sometimes things move around, this makes it so that we can at least recover some of that + if nw_class is None: + warnings.warn(f'Network class {network_class} not found. Attempting to locate it within ' + f'dynamic_network_architectures.architectures...') + import dynamic_network_architectures + nw_class = recursive_find_python_class(join(dynamic_network_architectures.__path__[0], "architectures"), + network_class.split(".")[-1], + 'dynamic_network_architectures.architectures') + if nw_class is not None: + print(f'FOUND IT: {nw_class}') + else: + raise ImportError('Network class could not be found, please check/correct your plans file') - label_manager = plans_manager.get_label_manager(dataset_json) + if deep_supervision is not None and 'deep_supervision' not in arch_kwargs.keys(): + arch_kwargs['deep_supervision'] = deep_supervision - segmentation_network_class_name = configuration_manager.UNet_class_name - mapping = { - 'PlainConvUNet': PlainConvUNet, - 'ResidualEncoderUNet': ResidualEncoderUNet - } - kwargs = { - 'PlainConvUNet': { - 'conv_bias': True, - 'norm_op': get_matching_instancenorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - }, - 'ResidualEncoderUNet': { - 'conv_bias': True, - 'norm_op': get_matching_instancenorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - } - } - assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ - 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ - 'into either this ' \ - 'function (get_network_from_plans) or ' \ - 'the init of your nnUNetModule to accommodate that.' - network_class = mapping[segmentation_network_class_name] - - conv_or_blocks_per_stage = { - 'n_conv_per_stage' - if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': configuration_manager.n_conv_per_stage_encoder, - 'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder - } - # network class name!! - model = network_class( - input_channels=num_input_channels, - n_stages=num_stages, - features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i, - configuration_manager.unet_max_num_features) for i in range(num_stages)], - conv_op=conv_op, - kernel_sizes=configuration_manager.conv_kernel_sizes, - strides=configuration_manager.pool_op_kernel_sizes, - num_classes=label_manager.num_segmentation_heads, - deep_supervision=deep_supervision, - **conv_or_blocks_per_stage, - **kwargs[segmentation_network_class_name] + network = nw_class( + input_channels=input_channels, + num_classes=output_channels, + **architecture_kwargs ) - model.apply(InitWeights_He(1e-2)) - if network_class == ResidualEncoderUNet: - model.apply(init_last_bn_before_add_to_0) - return model + + if hasattr(network, 'initialize') and allow_init: + network.apply(network.initialize) + + return network diff --git a/nnunetv2/utilities/json_export.py b/nnunetv2/utilities/json_export.py index 5ea463c..d6bcd06 100644 --- a/nnunetv2/utilities/json_export.py +++ b/nnunetv2/utilities/json_export.py @@ -5,7 +5,8 @@ def recursive_fix_for_json_export(my_dict: dict): - # json is stupid. 'cannot serialize object of type bool_/int64/float64'. Come on bro. + # json is ... a very nice thing to have + # 'cannot serialize object of type bool_/int64/float64'. Apart from that of course... keys = list(my_dict.keys()) # cannot iterate over keys() if we change keys.... for k in keys: if isinstance(k, (np.int64, np.int32, np.int8, np.uint8)): @@ -37,7 +38,7 @@ def recursive_fix_for_json_export(my_dict: dict): def fix_types_iterable(iterable, output_type): - # this sh!t is hacky as hell and will break if you use it for anything outside nnunet. Keep you hands off of this. + # this sh!t is hacky as hell and will break if you use it for anything outside nnunet. Keep your hands off of this. out = [] for i in iterable: if type(i) in (np.int64, np.int32, np.int8, np.uint8): diff --git a/nnunetv2/utilities/plans_handling/plans_handler.py b/nnunetv2/utilities/plans_handling/plans_handler.py index 6c39fd1..11b76df 100644 --- a/nnunetv2/utilities/plans_handling/plans_handler.py +++ b/nnunetv2/utilities/plans_handling/plans_handler.py @@ -1,6 +1,7 @@ from __future__ import annotations -import dynamic_network_architectures +import warnings + from copy import deepcopy from functools import lru_cache, partial from typing import Union, Tuple, List, Type, Callable @@ -9,8 +10,6 @@ import torch from nnunetv2.preprocessing.resampling.utils import recursive_find_resampling_fn_by_name -from torch import nn - import nnunetv2 from batchgenerators.utilities.file_and_folder_operations import load_json, join @@ -18,9 +17,9 @@ from nnunetv2.utilities.find_class_by_name import recursive_find_python_class from nnunetv2.utilities.label_handling.label_handling import get_labelmanager_class_from_plans - # see https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/ from typing import TYPE_CHECKING +from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm if TYPE_CHECKING: from nnunetv2.utilities.label_handling.label_handling import LabelManager @@ -33,6 +32,68 @@ class ConfigurationManager(object): def __init__(self, configuration_dict: dict): self.configuration = configuration_dict + # backwards compatibility + if 'architecture' not in self.configuration.keys(): + warnings.warn("Detected old nnU-Net plans format. Attempting to reconstruct network architecture " + "parameters. If this fails, rerun nnUNetv2_plan_experiment for your dataset. If you use a " + "custom architecture, please downgrade nnU-Net to the version you implemented this " + "or update your implementation + plans.") + # try to build the architecture information from old plans, modify configuration dict to match new standard + unet_class_name = self.configuration["UNet_class_name"] + if unet_class_name == "PlainConvUNet": + network_class_name = "dynamic_network_architectures.architectures.unet.PlainConvUNet" + elif unet_class_name == 'ResidualEncoderUNet': + network_class_name = "dynamic_network_architectures.architectures.residual_unet.ResidualEncoderUNet" + else: + raise RuntimeError(f'Unknown architecture {unet_class_name}. This conversion only supports ' + f'PlainConvUNet and ResidualEncoderUNet') + + n_stages = len(self.configuration["n_conv_per_stage_encoder"]) + + dim = len(self.configuration["patch_size"]) + conv_op = convert_dim_to_conv_op(dim) + instnorm = get_matching_instancenorm(dimension=dim) + + arch_dict = { + 'network_class_name': network_class_name, + 'arch_kwargs': { + "n_stages": n_stages, + "features_per_stage": [min(self.configuration["UNet_base_num_features"] * 2 ** i, + self.configuration["unet_max_num_features"]) + for i in range(n_stages)], + "conv_op": conv_op.__module__ + '.' + conv_op.__name__, + "kernel_sizes": deepcopy(self.configuration["conv_kernel_sizes"]), + "strides": deepcopy(self.configuration["pool_op_kernel_sizes"]), + "n_conv_per_stage": deepcopy(self.configuration["n_conv_per_stage_encoder"]), + "n_conv_per_stage_decoder": deepcopy(self.configuration["n_conv_per_stage_decoder"]), + "conv_bias": True, + "norm_op": instnorm.__module__ + '.' + instnorm.__name__, + "norm_op_kwargs": { + "eps": 1e-05, + "affine": True + }, + "dropout_op": None, + "dropout_op_kwargs": None, + "nonlin": "torch.nn.LeakyReLU", + "nonlin_kwargs": { + "inplace": True + } + }, + # these need to be imported with locate in order to use them: + # `conv_op = pydoc.locate(architecture_kwargs['conv_op'])` + "_kw_requires_import": [ + "conv_op", + "norm_op", + "dropout_op", + "nonlin" + ] + } + del self.configuration["UNet_class_name"], self.configuration["UNet_base_num_features"], \ + self.configuration["n_conv_per_stage_encoder"], self.configuration["n_conv_per_stage_decoder"], \ + self.configuration["num_pool_per_axis"], self.configuration["pool_op_kernel_sizes"],\ + self.configuration["conv_kernel_sizes"], self.configuration["unet_max_num_features"] + self.configuration["architecture"] = arch_dict + def __repr__(self): return self.configuration.__repr__() @@ -77,49 +138,20 @@ def use_mask_for_norm(self) -> List[bool]: return self.configuration['use_mask_for_norm'] @property - def UNet_class_name(self) -> str: - return self.configuration['UNet_class_name'] - - @property - @lru_cache(maxsize=1) - def UNet_class(self) -> Type[nn.Module]: - unet_class = recursive_find_python_class(join(dynamic_network_architectures.__path__[0], "architectures"), - self.UNet_class_name, - current_module="dynamic_network_architectures.architectures") - if unet_class is None: - raise RuntimeError('The network architecture specified by the plans file ' - 'is non-standard (maybe your own?). Fix this by not using ' - 'ConfigurationManager.UNet_class to instantiate ' - 'it (probably just overwrite build_network_architecture of your trainer.') - return unet_class - - @property - def UNet_base_num_features(self) -> int: - return self.configuration['UNet_base_num_features'] - - @property - def n_conv_per_stage_encoder(self) -> List[int]: - return self.configuration['n_conv_per_stage_encoder'] - - @property - def n_conv_per_stage_decoder(self) -> List[int]: - return self.configuration['n_conv_per_stage_decoder'] - - @property - def num_pool_per_axis(self) -> List[int]: - return self.configuration['num_pool_per_axis'] + def network_arch_class_name(self) -> str: + return self.configuration['architecture']['network_class_name'] @property - def pool_op_kernel_sizes(self) -> List[List[int]]: - return self.configuration['pool_op_kernel_sizes'] + def network_arch_init_kwargs(self) -> dict: + return self.configuration['architecture']['arch_kwargs'] @property - def conv_kernel_sizes(self) -> List[List[int]]: - return self.configuration['conv_kernel_sizes'] + def network_arch_init_kwargs_req_import(self) -> Union[Tuple[str, ...], List[str]]: + return self.configuration['architecture']['_kw_requires_import'] @property - def unet_max_num_features(self) -> int: - return self.configuration['unet_max_num_features'] + def pool_op_kernel_sizes(self) -> Tuple[Tuple[int, ...], ...]: + return self.configuration['architecture']['arch_kwargs']['strides'] @property @lru_cache(maxsize=1) diff --git a/pyproject.toml b/pyproject.toml index 91bc315..9a4c4cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "nnunetv2" -version = "2.2.1" +version = "2.3.1" requires-python = ">=3.9" description = "nnU-Net is a framework for out-of-the box image segmentation." readme = "readme.md" @@ -31,8 +31,8 @@ keywords = [ ] dependencies = [ "torch>=2.0.0", - "acvl-utils>=0.2", - "dynamic-network-architectures>=0.2", + "acvl-utils>=0.2,<0.3", # 0.3 may bring breaking changes. Careful! + "dynamic-network-architectures>=0.2,<0.4", # 0.3.1 and lower are supported, 0.4 may have breaking changes. Let's be careful here "tqdm", "dicom2nifti", "scipy",