diff --git a/.github/workflows/gpu-test-action.yml b/.github/workflows/gpu-test-action.yml
index 316532c..0590526 100644
--- a/.github/workflows/gpu-test-action.yml
+++ b/.github/workflows/gpu-test-action.yml
@@ -3,7 +3,7 @@ name: gpu-tests
 on:
   pull_request:
   push:
-    branches: main
+    branches: [dev, main]
 
 jobs:
   test-linux:
diff --git a/README.md b/README.md
index eb9ce29..045e384 100644
--- a/README.md
+++ b/README.md
@@ -9,15 +9,29 @@ The result is a deep-learning-based registration model that works well across da
 
 ![teaser](IntroFigure.jpg?raw=true)
 
+**uniGradICON: A Foundation Model for Medical Image Registration**
+Tian, Lin and Greer, Hastings and Kwitt, Roland and Vialard, Francois-Xavier and Estepar, Raul San Jose and Bouix, Sylvain and Rushmore, Richard and Niethammer, Marc
+_MICCAI 2024_ https://arxiv.org/abs/2403.05780
+
+**multiGradICON: A Foundation Model for Multimodal Medical Image Registration**
+Demir, Basar and Tian, Lin and Greer, Thomas Hastings and Kwitt, Roland and Vialard, Francois-Xavier and Estepar, Raul San Jose and Bouix, Sylvain and Rushmore, Richard Jarrett and Ebrahim, Ebrahim and Niethammer, Marc
+_MICCAI Workshop on Biomedical Image Registration (WBIR) 2024_ https://arxiv.org/abs/2408.00221
+
 Please (currently) cite as:
 
 ```
-@misc{tian2024unigradicon,
-      title={uniGradICON: A Foundation Model for Medical Image Registration},
-      author={Lin Tian and Hastings Greer and Roland Kwitt and Francois-Xavier Vialard and Raul San Jose Estepar and Sylvain Bouix and Richard Rushmore and Marc Niethammer},
-      year={2024},
-      eprint={2403.05780},
-      archivePrefix={arXiv},
-      primaryClass={cs.CV}
+@article{tian2024unigradicon,
+  title={uniGradICON: A Foundation Model for Medical Image Registration},
+  author={Tian, Lin and Greer, Hastings and Kwitt, Roland and Vialard, Francois-Xavier and Estepar, Raul San Jose and Bouix, Sylvain and Rushmore, Richard and Niethammer, Marc},
+  journal={arXiv preprint arXiv:2403.05780},
+  year={2024}
+}
+```
+```
+@article{demir2024multigradicon,
+  title={multiGradICON: A Foundation Model for Multimodal Medical Image Registration},
+  author={Demir, Basar and Tian, Lin and Greer, Thomas Hastings and Kwitt, Roland and Vialard, Francois-Xavier and Estepar, Raul San Jose and Bouix, Sylvain and Rushmore, Richard Jarrett and Ebrahim, Ebrahim and Niethammer, Marc},
+  journal={arXiv preprint arXiv:2408.00221},
+  year={2024}
 }
 ```
 
@@ -204,12 +218,25 @@ unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=Reg
 ```
 
-To register without instance optimization
+To register without instance optimization (IO)
 ```
 unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=RegLib_C01_1.nrrd --moving_modality=mri --transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd --io_iterations None
 ```
 
-To warp
+To use a different similarity measure during IO, pass `--io_sim`. We currently support three similarity measures:
+- LNCC: lncc
+- Squared LNCC: lncc2
+- MIND SSC: mind
+```
+unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=RegLib_C01_1.nrrd --moving_modality=mri --transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd --io_iterations 50 --io_sim lncc2
+```
+
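+The same functionality can also be called from Python. Below is a minimal sketch based on the functions in `src/unigradicon/__init__.py` and on `icon_registration.itk_wrapper.register_pair` (argument order and defaults may need adjusting for your installed version):
+```
+import itk
+from unigradicon import get_unigradicon, preprocess
+from icon_registration import itk_wrapper
+
+net = get_unigradicon()
+fixed = preprocess(itk.imread("RegLib_C01_2.nrrd"), modality="mri")
+moving = preprocess(itk.imread("RegLib_C01_1.nrrd"), modality="mri")
+
+# finetune_steps plays the role of --io_iterations; None disables IO
+phi_AB, phi_BA = itk_wrapper.register_pair(net, moving, fixed, finetune_steps=50)
+itk.transformwrite([phi_AB], "trans.hdf5")
+```
+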
+To load specific model weights for inference, pass `--model`. We currently support uniGradICON and multiGradICON.
+```
+unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=RegLib_C01_1.nrrd --moving_modality=mri --transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd --model multigradicon
+```
+
+To warp an image
 ```
 unigradicon-warp --fixed [fixed_image_file_name] --moving [moving_image_file_name] --transform trans.hdf5 --warped_moving_out warped.nii.gz --linear
 ```
@@ -218,8 +245,12 @@ To warp a label map
 ```
 unigradicon-warp --fixed [fixed_image_file_name] --moving [moving_image_segmentation_file_name] --transform trans.hdf5 --warped_moving_out warped_seg.nii.gz --nearest_neighbor
 ```
+
 We also provide a [colab](https://colab.research.google.com/drive/1JuFL113WN3FHCoXG-4fiBTWIyYpwGyGy?usp=sharing) demo.
 
+## Slicer Extension
+
+A Slicer extension is available [here](https://github.com/uncbiag/SlicerUniGradICON?tab=readme-ov-file) (and hopefully will soon be available via the Slicer Extension Manager).
 
 ## Plays well with others
 
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..994a988
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+icon_registration>=1.1.5
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
index a0575b5..f50b557 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,8 +1,8 @@
 [metadata]
 name = unigradicon
-version = 1.0.2
+version = 1.0.3
 author = Lin Tian
-author_email = 
+author_email = lintian@cs.unc.edu
 description = a foundation model for medical image registration
 long_description = file: README.md
 long_description_content_type = text/markdown
@@ -21,7 +21,7 @@ packages = find:
 python_requires = >=3.7
 install_requires =
-    icon_registration>=1.1.4
+    icon_registration>=1.1.5
 
 [options.packages.find]
 where = src
diff --git a/src/unigradicon/__init__.py b/src/unigradicon/__init__.py
index 4b8e9a3..032b86b 100644
--- a/src/unigradicon/__init__.py
+++ b/src/unigradicon/__init__.py
@@ -15,23 +15,27 @@ from icon_registration.mermaidlite import compute_warped_image_multiNC
 import icon_registration.itk_wrapper
 
-
-
 input_shape = [1, 1, 175, 175, 175]
 
 class GradientICONSparse(network_wrappers.RegistrationModule):
-    def __init__(self, network, similarity, lmbda):
+    def __init__(self, network, similarity, lmbda, use_label=False):
         super().__init__()
 
         self.regis_net = network
         self.lmbda = lmbda
         self.similarity = similarity
+        self.use_label = use_label
 
-    def forward(self, image_A, image_B):
+    def forward(self, image_A, image_B, label_A=None, label_B=None):
         assert self.identity_map.shape[2:] == image_A.shape[2:]
         assert self.identity_map.shape[2:] == image_B.shape[2:]
 
+        if self.use_label:
+            label_A = image_A if label_A is None else label_A
+            label_B = image_B if label_B is None else label_B
+            assert self.identity_map.shape[2:] == label_A.shape[2:]
+            assert self.identity_map.shape[2:] == label_B.shape[2:]
+
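+        # NOTE (added comment): when use_label is set but no labels are passed,
+        # the images themselves stand in as labels, so the label-similarity
+        # branch below reduces to the plain image similarity.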
 
         # Tag used elsewhere for optimization.
         # Must be set at beginning of forward b/c not preserved by .cuda() etc
@@ -75,10 +79,29 @@ def forward(self, image_A, image_B):
             1,
             zero_boundary=True
         )
-
-        similarity_loss = self.similarity(
-            self.warped_image_A, image_B
-        ) + self.similarity(self.warped_image_B, image_A)
+
+        if self.use_label:
+            self.warped_label_A = compute_warped_image_multiNC(
+                torch.cat([label_A, inbounds_tag], axis=1) if inbounds_tag is not None else label_A,
+                self.phi_AB_vectorfield,
+                self.spacing,
+                1,
+            )
+
+            self.warped_label_B = compute_warped_image_multiNC(
+                torch.cat([label_B, inbounds_tag], axis=1) if inbounds_tag is not None else label_B,
+                self.phi_BA_vectorfield,
+                self.spacing,
+                1,
+            )
+
+            similarity_loss = self.similarity(
+                self.warped_label_A, label_B
+            ) + self.similarity(self.warped_label_B, label_A)
+        else:
+            similarity_loss = self.similarity(
+                self.warped_image_A, image_B
+            ) + self.similarity(self.warped_image_B, image_A)
 
         if len(self.input_shape) - 2 == 3:
             Iepsilon = (
@@ -142,8 +165,10 @@ def forward(self, image_A, image_B):
 
     def clean(self):
         del self.phi_AB, self.phi_BA, self.phi_AB_vectorfield, self.phi_BA_vectorfield, self.warped_image_A, self.warped_image_B
+        if self.use_label:
+            del self.warped_label_A, self.warped_label_B
 
-def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.LNCC(sigma=5)):
+def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.LNCC(sigma=5), use_label=False):
     dimension = len(input_shape) - 2
     inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=dimension))
 
@@ -155,17 +180,44 @@ def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.L
     if include_last_step:
         inner_net = icon.TwoStepRegistration(inner_net, icon.FunctionFromVectorField(networks.tallUNet2(dimension=dimension)))
 
-    net = GradientICONSparse(inner_net, loss_fn, lmbda=lmbda)
+    net = GradientICONSparse(inner_net, loss_fn, lmbda=lmbda, use_label=use_label)
     net.assign_identity_map(input_shape)
     return net
 
+def make_sim(similarity):
+    if similarity == "lncc":
+        return icon.LNCC(sigma=5)
+    elif similarity == "lncc2":
+        return icon.SquaredLNCC(sigma=5)
+    elif similarity == "mind":
+        return icon.MINDSSC(radius=2, dilation=2)
+    else:
+        raise ValueError(f"Similarity measure {similarity} not recognized. 
Choose from [lncc, lncc2, mind].") + +def get_multigradicon(loss_fn=icon.LNCC(sigma=5)): + net = make_network(input_shape, include_last_step=True, loss_fn=loss_fn) + from os.path import exists + weights_location = "network_weights/multigradicon1.0/Step_2_final.trch" + if not exists(weights_location): + print("Downloading pretrained multigradicon model") + import urllib.request + import os + download_path = "https://github.com/uncbiag/uniGradICON/releases/download/multigradicon_weights/Step_2_final.trch" + os.makedirs("network_weights/multigradicon1.0/", exist_ok=True) + urllib.request.urlretrieve(download_path, weights_location) + print(f"Loading weights from {weights_location}") + trained_weights = torch.load(weights_location, map_location=torch.device("cpu")) + net.regis_net.load_state_dict(trained_weights) + net.to(config.device) + net.eval() + return net -def get_unigradicon(): - net = make_network(input_shape, include_last_step=True) +def get_unigradicon(loss_fn=icon.LNCC(sigma=5)): + net = make_network(input_shape, include_last_step=True, loss_fn=loss_fn) from os.path import exists weights_location = "network_weights/unigradicon1.0/Step_2_final.trch" if not exists(weights_location): - print("Downloading pretrained model") + print("Downloading pretrained unigradicon model") import urllib.request import os download_path = "https://github.com/uncbiag/uniGradICON/releases/download/unigradicon_weights/Step_2_final.trch" @@ -177,6 +229,14 @@ def get_unigradicon(): net.eval() return net +def get_model_from_model_zoo(model_name="unigradicon", loss_fn=icon.LNCC(sigma=5)): + if model_name == "unigradicon": + return get_unigradicon(loss_fn) + elif model_name == "multigradicon": + return get_multigradicon(loss_fn) + else: + raise ValueError(f"Model {model_name} not recognized. Choose from [unigradicon, multigradicon].") + def quantile(arr: torch.Tensor, q): arr = arr.flatten() l = len(arr) @@ -202,7 +262,7 @@ def preprocess(image, modality="ct", segmentation=None): min_ = -1000 max_ = 1000 image = itk.CastImageFilter[type(image), itk.Image[itk.F, 3]].New()(image) - image = itk.clamp_image_filter(image, Bounds=(-1000, 1000)) + image = itk.clamp_image_filter(image, Bounds=(min_, max_)) elif modality == "mri": image = itk.CastImageFilter[type(image), itk.Image[itk.F, 3]].New()(image) min_, _ = itk.image_intensity_min_max(image) @@ -241,10 +301,14 @@ def main(): default=None, type=str, help="The path to save the warped image.") parser.add_argument("--io_iterations", required=False, default="50", help="The number of IO iterations. Default is 50. Set to 'None' to disable IO.") + parser.add_argument("--io_sim", required=False, + default="lncc", help="The similarity measure used in IO. Default is LNCC. Choose from [lncc, lncc2, mind].") + parser.add_argument("--model", required=False, + default="unigradicon", help="The model to load. Default is unigradicon. 
Choose from [unigradicon, multigradicon].") args = parser.parse_args() - net = get_unigradicon() + net = get_model_from_model_zoo(args.model, make_sim(args.io_sim)) fixed = itk.imread(args.fixed) moving = itk.imread(args.moving) @@ -345,6 +409,3 @@ def maybe_cast(img: itk.Image): return img, maybe_cast_back - - - diff --git a/tests/test_command_arguments.py b/tests/test_command_arguments.py new file mode 100644 index 0000000..069839d --- /dev/null +++ b/tests/test_command_arguments.py @@ -0,0 +1,110 @@ +import itk +import numpy as np +import unittest +import icon_registration.test_utils + +import subprocess +import os +import torch + + +class TestCommandInterface(unittest.TestCase): + def __init__(self, methodName: str = "runTest") -> None: + super().__init__(methodName) + icon_registration.test_utils.download_test_data() + self.test_data_dir = icon_registration.test_utils.TEST_DATA_DIR + self.test_temp_dir = f"{self.test_data_dir}/temp" + os.makedirs(self.test_temp_dir, exist_ok=True) + self.device = torch.cuda.current_device() + + def test_register_unigradicon_inference(self): + subprocess.run([ + "unigradicon-register", + "--fixed", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_img.nii.gz", + "--fixed_modality", "ct", + "--fixed_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_label.nii.gz", + "--moving", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_img.nii.gz", + "--moving_modality", "ct", + "--moving_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_label.nii.gz", + "--transform_out", f"{self.test_temp_dir}/transform.hdf5", + "--io_iterations", "None" + ]) + + # load transform + phi_AB = itk.transformread(f"{self.test_temp_dir}/transform.hdf5")[0] + + assert isinstance(phi_AB, itk.CompositeTransform) + + insp_points = icon_registration.test_utils.read_copd_pointset( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "lung_test_data/copd1_300_iBH_xyz_r1.txt" + ) + ) + exp_points = icon_registration.test_utils.read_copd_pointset( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "lung_test_data/copd1_300_eBH_xyz_r1.txt" + ) + ) + + dists = [] + for i in range(len(insp_points)): + px, py = ( + insp_points[i], + np.array(phi_AB.TransformPoint(tuple(exp_points[i]))), + ) + dists.append(np.sqrt(np.sum((px - py) ** 2))) + print(np.mean(dists)) + self.assertLess(np.mean(dists), 2.1) + + # remove temp file + os.remove(f"{self.test_temp_dir}/transform.hdf5") + + def test_register_multigradicon_inference(self): + subprocess.run([ + "unigradicon-register", + "--fixed", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_img.nii.gz", + "--fixed_modality", "ct", + "--fixed_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_label.nii.gz", + "--moving", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_img.nii.gz", + "--moving_modality", "ct", + "--moving_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_label.nii.gz", + "--transform_out", f"{self.test_temp_dir}/transform.hdf5", + "--io_iterations", "None", + "--model", "multigradicon" + ]) + + # load transform + phi_AB = itk.transformread(f"{self.test_temp_dir}/transform.hdf5")[0] + + assert isinstance(phi_AB, itk.CompositeTransform) + + insp_points = icon_registration.test_utils.read_copd_pointset( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "lung_test_data/copd1_300_iBH_xyz_r1.txt" + ) + ) + exp_points = 
icon_registration.test_utils.read_copd_pointset( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "lung_test_data/copd1_300_eBH_xyz_r1.txt" + ) + ) + + dists = [] + for i in range(len(insp_points)): + px, py = ( + insp_points[i], + np.array(phi_AB.TransformPoint(tuple(exp_points[i]))), + ) + dists.append(np.sqrt(np.sum((px - py) ** 2))) + print(np.mean(dists)) + self.assertLess(np.mean(dists), 3.8) + + # remove temp file + os.remove(f"{self.test_temp_dir}/transform.hdf5") + + + diff --git a/tests/test_requirements_sync.py b/tests/test_requirements_sync.py new file mode 100644 index 0000000..f56b77a --- /dev/null +++ b/tests/test_requirements_sync.py @@ -0,0 +1,19 @@ +import unittest + + +class TestImports(unittest.TestCase): + + def test_requirements_match_cfg(self): + from inspect import getsourcefile + import os.path as path, sys + import configparser + + current_dir = path.dirname(path.abspath(getsourcefile(lambda: 0))) + parent_dir = current_dir[: current_dir.rfind(path.sep)] + + with open(parent_dir + "/requirements.txt") as f: + requirements_txt = "\n" + f.read() + requirements_cfg = configparser.ConfigParser() + requirements_cfg.read(parent_dir + "/setup.cfg") + requirements_cfg = requirements_cfg["options"]["install_requires"] + self.assertEqual(requirements_txt, requirements_cfg) diff --git a/training/dataset.py b/training/dataset.py new file mode 100644 index 0000000..ff948a0 --- /dev/null +++ b/training/dataset.py @@ -0,0 +1,291 @@ +import torch +import os +import torch.nn.functional as F +import random +import numpy as np +import itk +import glob +import SimpleITK as sitk +from tqdm import tqdm + +DATASET_DIR = "./data/uniGradICON/" + +class COPDDataset(torch.utils.data.Dataset): + def __init__( + self, + phase="train", + scale="2xdown", + data_path=f"{DATASET_DIR}/half_res_preprocessed_transposed_SI", + ROI_only=False, + data_num=-1, + desired_shape=None, + device="cpu" + ): + if phase == "debug": + phase = "train" + self.imgs = torch.load( + f"{data_path}/lungs_{phase}_{scale}_scaled", map_location="cpu" + ) + if data_num <= 0 or data_num > len(self.imgs): + self.data_num = len(self.imgs) + else: + self.data_num = data_num + self.imgs = self.imgs[: self.data_num] + + # Normalize to [0,1] + print("Processing COPD data.") + if ROI_only: + segs = torch.load( + f"{data_path}/lungs_seg_{phase}_{scale}_scaled", map_location="cpu" + )[: self.data_num] + self.imgs = list(map(lambda x: (self.process(x[0][0], desired_shape, device, x[1][0])[0],self.process(x[0][1], desired_shape, device, x[1][1])[0]), tqdm(zip(self.imgs, segs)))) + else: + self.imgs = list(map(lambda x: (self.process(x[0], desired_shape, device)[0],self.process(x[1], desired_shape, device)[0]), tqdm(self.imgs))) + + + def process(self, img, desired_shape=None, device="cpu", seg=None): + img = img.to(device) + im_min, im_max = torch.min(img), torch.max(img) + img = (img-im_min) / (im_max-im_min) + if seg is not None: + seg = seg.to(device) + img = (img * seg).float() + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + img_a, img_b = self.imgs[idx] + return img_a, img_b + + +class OAIDataset(torch.utils.data.Dataset): + def __init__( + self, + phase="train", + scale="2xdownsample", + data_path=f"{DATASET_DIR}/OAI", + data_num=1000, + desired_shape=None, + device="cpu" + ): + if phase == "debug": + phase = "train" + if phase == "test": + phase = "train" + print( + "WARNING: 
There is no validation set for OAI. Using train data for test set."
+            )
+        self.imgs = torch.load(
+            f"{data_path}/knees_big_{scale}_train_set", map_location="cpu"
+        )
+
+        print("Processing OAI data.")
+        self.imgs = list(map(lambda x: self.process(x, desired_shape, device)[0], tqdm(self.imgs)))
+
+        self.img_num = len(self.imgs)
+
+        self.data_num = data_num
+
+    def process(self, img, desired_shape=None, device="cpu"):
+        img = img.to(device)
+        im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99)
+        img = torch.clip(img, im_min, im_max)
+        img = (img-im_min) / (im_max-im_min)
+
+        if desired_shape is not None:
+            img = F.interpolate(img, desired_shape, mode="trilinear")
+        return img.cpu()
+
+    def __len__(self):
+        return self.data_num
+
+    def __getitem__(self, idx):
+        idx_a = random.randint(0, self.img_num - 1)
+        idx_b = random.randint(0, self.img_num - 1)
+        img_a = self.imgs[idx_a]
+        img_b = self.imgs[idx_b]
+        return img_a, img_b
+
+
+class HCPDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        phase="train",
+        scale="2xdown",
+        data_path=f"{DATASET_DIR}/HCP",
+        data_num=1000,
+        desired_shape=None,
+        device="cpu"
+    ):
+        if phase == "debug":
+            phase = "train"
+        if phase == "test":
+            phase = "train"
+            print(
+                "WARNING: There is no validation set for HCP. Using train data for test set."
+            )
+        self.imgs = torch.load(
+            f"{data_path}/brain_train_{scale}_scaled", map_location="cpu"
+        )
+        print("Processing HCP data.")
+        self.imgs = list(map(lambda x: self.process(x, desired_shape, device)[0], tqdm(self.imgs)))
+
+        self.img_num = len(self.imgs)
+
+        self.data_num = data_num
+
+    def process(self, img, desired_shape=None, device="cpu"):
+        img = img.to(device)
+        im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99)
+        img = torch.clip(img, im_min, im_max)
+        img = (img-im_min) / (im_max-im_min)
+        if desired_shape is not None:
+            img = F.interpolate(img, desired_shape, mode="trilinear")
+        return img.cpu()
+
+    def __len__(self):
+        return self.data_num
+
+    def __getitem__(self, idx):
+        idx_a = random.randint(0, self.img_num - 1)
+        idx_b = random.randint(0, self.img_num - 1)
+        img_a = self.imgs[idx_a]
+        img_b = self.imgs[idx_b]
+        return img_a, img_b
+
+
+class L2rAbdomenDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        data_path=f"{DATASET_DIR}/AbdomenCTCT",
+        data_num=1000,
+        desired_shape=None,
+        device="cpu"
+    ):
+        cases = list(map(lambda x: os.path.join(f"{data_path}/imagesTr", x), os.listdir(f"{data_path}/imagesTr/")))
+        self.imgs = []
+        print("Processing L2R Abdomen data.")
+        for i in tqdm(range(len(cases))):
+            case_path = cases[i]
+            self.imgs.append(self.process(torch.tensor(np.asarray(itk.imread(case_path)))[None, None], desired_shape, device)[0])
+
+        self.img_num = len(self.imgs)
+
+        self.data_num = data_num
+
+    def process(self, img, desired_shape=None, device="cpu"):
+        img = img.to(device)
+        img = (torch.clip(img.float(), -1000, 1000)+1000)/2000
+        if desired_shape is not None:
+            img = F.interpolate(img, desired_shape, mode="trilinear")
+        return img.cpu()
+
+    def __len__(self):
+        return self.data_num
+
+    def __getitem__(self, idx):
+        idx_a = random.randint(0, self.img_num - 1)
+        idx_b = random.randint(0, self.img_num - 1)
+        img_a = self.imgs[idx_a]
+        img_b = self.imgs[idx_b]
+        return img_a, img_b
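+
+# NOTE (added comment): the CT datasets here normalize intensities to [0, 1]
+# with (clip(HU, -1000, 1000) + 1000) / 2000 -- the same [-1000, 1000] HU
+# window that preprocess() applies to CT inputs in src/unigradicon/__init__.py.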
+
+class L2rThoraxCBCTDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        data_path=f"{DATASET_DIR}/ThoraxCBCT",
+        data_num=1000,
+        desired_shape=None,
+        device="cpu"
+    ):
+        import json
+        with open(f"{data_path}/ThoraxCBCT_dataset.json", 'r') as data_info:
+            data_info = json.loads(data_info.read())
+        cases = [[f"{data_path}/{c['moving']}", f"{data_path}/{c['fixed']}"] for c in data_info["training_paired_images"]]
+        self.imgs = []
+        print("Processing L2R ThoraxCBCT data.")
+        for i in tqdm(range(len(cases))):
+            moving_path, fixed_path = cases[i]
+            self.imgs.append(
+                (
+                    self.process(torch.tensor(np.asarray(itk.imread(moving_path)))[None, None], desired_shape, device)[0],
+                    self.process(torch.tensor(np.asarray(itk.imread(fixed_path)))[None, None], desired_shape, device)[0]
+                ))
+
+        if data_num < len(self.imgs):
+            self.imgs = self.imgs[:data_num]
+        self.data_num = len(self.imgs)
+
+    def process(self, img, desired_shape=None, device="cpu"):
+        img = img.to(device)
+        img = (torch.clip(img.float(), -1000, 1000)+1000)/2000
+        if desired_shape is not None:
+            img = F.interpolate(img, desired_shape, mode="trilinear")
+        return img.cpu()
+
+    def __len__(self):
+        return self.data_num
+
+    def __getitem__(self, idx):
+        img_a, img_b = self.imgs[idx]
+        return img_a, img_b
+
+
+class ACDCDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        data_path=f"{DATASET_DIR}/ACDC",
+        desired_shape=None
+    ):
+        self.imgs = []
+        # glob is imported as a module above, so use glob.glob here
+        training_files = sorted(glob.glob(f"{data_path}/database/training/*/*4d.nii.gz"))
+
+        for file in training_files:
+            self.imgs.append(
+                self.process(
+                    torch.tensor(
+                        sitk.GetArrayFromImage(sitk.ReadImage(file, sitk.sitkFloat32))
+                    )[None],
+                    desired_shape)
+            )
+
+        self.img_num = len(self.imgs)
+
+
+    def process(self, img, desired_shape=None):
+        img = img / torch.amax(img, dim=(1,2,3), keepdim=True)
+
+        # Pad the image (torch has no torch.pad; use torch.nn.functional.pad)
+        pad_size = (img.shape[2] - img.shape[1]) // 2
+        img = F.pad(img, (0, 0, 0, 0, pad_size, pad_size), "constant", 0)
+
+        if desired_shape is not None:
+            img = F.interpolate(img, desired_shape, mode="trilinear")
+        return img
+
+    def __len__(self):
+        return self.img_num
+
+    def __getitem__(self, idx):
+        img = self.imgs[idx]
+        img_count = img.shape[0]
+        idx_a = random.randint(0, img_count - 1)
+        idx_b = random.randint(0, img_count - 1)
+        img_a = img[idx_a]
+        img_b = img[idx_b]
+        return img_a, img_b
+
+
+if __name__ == "__main__":
+    from torch.utils.data import DataLoader
+    datasets = [COPDDataset, OAIDataset, HCPDataset, L2rAbdomenDataset]
+
+    for dataset in datasets:
+        data = dataset(desired_shape=(64, 64, 64))
+        data = DataLoader(data, batch_size=3)
+        print(next(iter(data))[0].shape)
diff --git a/training/dataset_multi.py b/training/dataset_multi.py
new file mode 100644
index 0000000..08cf330
--- /dev/null
+++ b/training/dataset_multi.py
@@ -0,0 +1,813 @@
+import torch
+import os
+import torch.nn.functional as F
+import random
+import numpy as np
+import itk
+import glob
+import SimpleITK as sitk
+from tqdm import tqdm
+import blosc
+import json
+
+blosc.set_nthreads(1)
+DATASET_DIR = "./data/multiGradICON/"
+
+class COPDDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        scale="2xdown",
+        data_path=f"{DATASET_DIR}/half_res_preprocessed_transposed_SI",
+        ROI_only=False,
+        data_num=-1,
+        desired_shape=None,
+        device="cpu",
+        return_labels=False
+    ):
+        self.imgs = torch.load(f"{data_path}/lungs_train_{scale}_scaled", map_location="cpu")
+        self.data_num = data_num
+        self.desired_shape = desired_shape
+        self.device = device
+        self.return_labels = return_labels
+        self.modalities = ['ct']
+        self.anatomies = ['lung']
+        self.region_num = 1
+
+        print("Processing COPD data.")
+        if ROI_only:
+            segs = torch.load(
+                f"{data_path}/lungs_seg_train_{scale}_scaled", map_location="cpu"
+            )[: self.data_num]
+            self.imgs = 
list(map(lambda x: (self.pack_and_process_image(x[0][0], x[1][0]), self.pack_and_process_image(x[0][1], x[1][1])), tqdm(zip(self.imgs, segs)))) + else: + self.imgs = list(map(lambda x: (self.pack_and_process_image(x[0]), self.pack_and_process_image(x[1])), tqdm(self.imgs))) + + def pack_and_process_image(self, img, seg=None): + processed_image = self.process(img, self.desired_shape, self.device, seg)[0] + array_image = processed_image.numpy() + return blosc.pack_array(array_image) + + def process(self, img, desired_shape=None, device="cpu", seg=None): + img = img.to(device) + im_min, im_max = torch.min(img), torch.max(img) + img = (img-im_min) / (im_max-im_min) + if seg is not None: + seg = seg.to(device) + img = (img * seg).float() + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + img_a, img_b = random.choice(self.imgs) + + if self.return_labels: + return blosc.unpack_array(img_a), blosc.unpack_array(img_b), blosc.unpack_array(img_a), blosc.unpack_array(img_b) + else: + return blosc.unpack_array(img_a), blosc.unpack_array(img_b) + +class BratsRegDataset(torch.utils.data.Dataset): + def __init__( + self, + data_path=f"{DATASET_DIR}/BraTS-Reg/BraTSReg_Training_Data_v3/", + data_num=1000, + desired_shape=None, + device="cpu", + return_labels=False, + randomization = 'random' + ): + + folders = sorted(glob.glob(data_path + '*/')) + self.pre_images, self.post_images = {'t1': [], 't1ce': [], 't2': [], 'flair': []}, {'t1': [], 't1ce': [], 't2': [], 'flair': []} + self.desired_shape = desired_shape + self.device = device + self.return_labels = return_labels + self.randomization = randomization + assert self.randomization in ['random', 'fixed'] + self.anatomies = ['brain'] + self.region_num = 1 + + for folder in tqdm(folders): + t1 = itk.imread(glob.glob(folder + 'BraTSReg_*_00_*_t1.nii.gz')[0]) + t1ce = itk.imread(glob.glob(folder + 'BraTSReg_*_00_*_t1ce.nii.gz')[0]) + t2 = itk.imread(glob.glob(folder + 'BraTSReg_*_00_*_t2.nii.gz')[0]) + flair = itk.imread(glob.glob(folder + 'BraTSReg_*_00_*_flair.nii.gz')[0]) + + self.pre_images['t1'].append(self.pack_and_process_image(t1)) + self.pre_images['t1ce'].append(self.pack_and_process_image(t1ce)) + self.pre_images['t2'].append(self.pack_and_process_image(t2)) + self.pre_images['flair'].append(self.pack_and_process_image(flair)) + + t1 = itk.imread(glob.glob(folder + 'BraTSReg_*_01_*_t1.nii.gz')[0]) + t1ce = itk.imread(glob.glob(folder + 'BraTSReg_*_01_*_t1ce.nii.gz')[0]) + t2 = itk.imread(glob.glob(folder + 'BraTSReg_*_01_*_t2.nii.gz')[0]) + flair = itk.imread(glob.glob(folder + 'BraTSReg_*_01_*_flair.nii.gz')[0]) + + self.post_images['t1'].append(self.pack_and_process_image(t1)) + self.post_images['t1ce'].append(self.pack_and_process_image(t1ce)) + self.post_images['t2'].append(self.pack_and_process_image(t2)) + self.post_images['flair'].append(self.pack_and_process_image(flair)) + + self.data_num = data_num + self.image_num = len(self.pre_images['t1']) + self.modalities = list(self.pre_images.keys()) + + def pack_and_process_image(self, image): + processed_image = self.process(torch.tensor(np.asarray(image))[None, None], self.desired_shape, self.device)[0] + array_image = processed_image.numpy() + return blosc.pack_array(array_image) + + def process(self, img, desired_shape=None, device="cpu"): + img = img.to(device).float() + im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99) + img = 
torch.clip(img, im_min, im_max)
+        img = (img-im_min) / (im_max-im_min)
+        if desired_shape is not None:
+            img = F.interpolate(img, desired_shape, mode="trilinear")
+        return img.cpu().float()
+
+    def __len__(self):
+        return self.data_num
+
+    def __getitem__(self, idx):
+        idx1 = random.randint(0, self.image_num-1)
+
+        img_a = self.pre_images[random.choice(self.modalities)][idx1]
+        img_b = self.post_images[random.choice(self.modalities)][idx1]
+
+        if not self.return_labels:
+            return blosc.unpack_array(img_a), blosc.unpack_array(img_b)
+
+        if self.randomization == 'random':
+            label_a = self.pre_images[random.choice(self.modalities)][idx1]
+            label_b = self.post_images[random.choice(self.modalities)][idx1]
+        else:
+            modality = random.choice(self.modalities)
+            label_a = self.pre_images[modality][idx1]
+            label_b = self.post_images[modality][idx1]
+
+        return blosc.unpack_array(img_a), blosc.unpack_array(img_b), blosc.unpack_array(label_a), blosc.unpack_array(label_b)
+
+class L2rAbdomenDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        data_path=f"{DATASET_DIR}/AbdomenCTCT",
+        data_num=1000,
+        desired_shape=None,
+        device="cpu",
+        return_labels=False,
+        randomization = 'random',
+        augmentation = True
+    ):
+        cases = list(map(lambda x: os.path.join(f"{data_path}/imagesTr", x), os.listdir(f"{data_path}/imagesTr/")))
+        self.desired_shape = desired_shape
+        self.device = device
+        self.return_labels = return_labels
+        self.randomization = randomization
+        assert self.randomization in ['random', 'fixed']
+        self.anatomies = ['abdomen']
+        self.region_num = 1
+
+        self.imgs = {'ct' : []}
+        if augmentation:
+            self.imgs['1-ct'] = []
+
+        print("Processing L2R Abdomen data.")
+        for i in tqdm(range(len(cases))):
+            case_path = cases[i]
+            self.imgs['ct'].append(self.pack_and_process_image(case_path))
+            if augmentation:
+                self.imgs['1-ct'].append(self.pack_and_process_image(case_path, invert=True))
+
+        # self.imgs maps modality -> list of cases; count the cases, not the keys
+        self.img_num = len(self.imgs['ct'])
+        self.data_num = data_num
+        self.modalities = list(self.imgs.keys())
+
+    def pack_and_process_image(self, case_path, invert=False):
+        processed_image = self.process(torch.tensor(np.asarray(itk.imread(case_path)))[None, None], self.desired_shape, self.device)[0]
+        array_image = processed_image.numpy()
+        if invert:
+            array_image = 1 - array_image
+        return blosc.pack_array(array_image)
+
+    def process(self, img, desired_shape=None, device="cpu"):
+        img = img.to(device)
+        img = (torch.clip(img.float(), -1000, 1000)+1000)/2000
+        if desired_shape is not None:
+            img = F.interpolate(img, desired_shape, mode="trilinear")
+        return img.cpu()
+
+    def __len__(self):
+        return self.data_num
+
+    def __getitem__(self, idx):
+        idx_a = random.randint(0, self.img_num - 1)
+        idx_b = random.randint(0, self.img_num - 1)
+        img_a = self.imgs[random.choice(self.modalities)][idx_a]
+        img_b = self.imgs[random.choice(self.modalities)][idx_b]
+
+        if not self.return_labels:
+            return blosc.unpack_array(img_a), blosc.unpack_array(img_b)
+
+        if self.randomization == 'random':
+            label_a = self.imgs[random.choice(self.modalities)][idx_a]
+            label_b = self.imgs[random.choice(self.modalities)][idx_b]
+        else:
+            modality = random.choice(self.modalities)
+            label_a = self.imgs[modality][idx_a]
+            label_b = self.imgs[modality][idx_b]
+
+        return blosc.unpack_array(img_a), blosc.unpack_array(img_b), blosc.unpack_array(label_a), blosc.unpack_array(label_b)
+
+class HCPDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        scale="2xdown",
+        data_path=f"{DATASET_DIR}/ICON_brain_preprocessed_data",
+        data_num=1000,
+        
desired_shape=None, + device="cpu", + return_labels=False, + randomization = 'random' + ): + self.desired_shape = desired_shape + self.device = device + self.return_labels = return_labels + self.randomization = randomization + assert self.randomization in ['random', 'fixed'] + self.anatomies = ['brain'] + self.region_num = 1 + + imgsT1 = torch.load( + f"{data_path}/brain_train_{scale}_scaled_T1", map_location="cpu" + ) + imgsT2 = torch.load( + f"{data_path}/brain_train_{scale}_scaled_T2", map_location="cpu" + ) + imgsT1 = list(map(lambda x: self.pack_and_process_image(x), tqdm(imgsT1))) + imgsT2 = list(map(lambda x: self.pack_and_process_image(x), tqdm(imgsT2))) + + self.imgs = {'T1': imgsT1, 'T2': imgsT2} + + self.img_num = len(imgsT1) + self.data_num = data_num + self.modalities = list(self.imgs.keys()) + + def pack_and_process_image(self, image): + processed_image = self.process(image, self.desired_shape, self.device)[0] + array_image = processed_image.numpy() + return blosc.pack_array(array_image) + + def process(self, img, desired_shape=None, device="cpu"): + img = img.to(device) + im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99) + img = torch.clip(img, im_min, im_max) + img = (img-im_min) / (im_max-im_min) + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + idx_a = random.randint(0, self.img_num - 1) + idx_b = random.randint(0, self.img_num - 1) + + img_a = self.imgs[random.choice(self.modalities)][idx_a] + img_b = self.imgs[random.choice(self.modalities)][idx_b] + + if not self.return_labels: + return blosc.unpack_array(img_a), blosc.unpack_array(img_b) + + if self.randomization == 'random': + label_a = self.imgs[random.choice(self.modalities)][idx_a] + label_b = self.imgs[random.choice(self.modalities)][idx_b] + else: + modality = random.choice(self.modalities) + label_a = self.imgs[modality][idx_a] + label_b = self.imgs[modality][idx_b] + + return blosc.unpack_array(img_a), blosc.unpack_array(img_b), blosc.unpack_array(label_a), blosc.unpack_array(label_b) + +class ABCDFAMDDataset(torch.utils.data.Dataset): + def __init__( + self, + phase="train", + data_path=f"{DATASET_DIR}/dti_scalars", + data_num=1000, + desired_shape=None, + device="cpu", + return_labels=False, + randomization = 'random' + ): + self.desired_shape = desired_shape + self.device = device + self.return_labels = return_labels + self.randomization = randomization + assert self.randomization in ['random', 'fixed'] + self.images = {} + self.anatomies = ['brain'] + self.region_num = 1 + + md_files = sorted(glob.glob(f'{data_path}/md/' + '*.nii.gz')) + fa_files = sorted(glob.glob(f'{data_path}/fa/' + '*.nii.gz')) + + if phase == "train": + fa_files = fa_files[10:] + md_files = md_files[10:] + elif phase == "val": + fa_files = fa_files[:10] + md_files = md_files[:10] + + fa_ids = {x.split('/')[-1].split('.')[0].split('_')[0].split('-')[1] : x for x in fa_files} + md_ids = {x.split('/')[-1].split('.')[0].split('_')[0].split('-')[1] : x for x in md_files} + + for fa_id in tqdm(fa_ids): + if fa_id not in md_ids: + continue + + fa = itk.imread(fa_ids[fa_id]) + md = itk.imread(md_ids[fa_id]) + self.images[fa_id] = {'FA': self.pack_and_process_image(fa), 'MD': self.pack_and_process_image(md)} + + self.data_num = data_num + self.image_ids = list(self.images.keys()) + self.modalities = ['FA', 'MD'] + + def pack_and_process_image(self, image): + processed_image = 
self.process(torch.tensor(np.asarray(image))[None, None], self.desired_shape, self.device)[0]
+        array_image = processed_image.numpy()
+        return blosc.pack_array(array_image)
+
+    def process(self, img, desired_shape=None, device="cpu"):
+        img = img.to(device).float()
+        im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99)
+        img = torch.clip(img, im_min, im_max)
+        img = (img-im_min) / (im_max-im_min)
+        if desired_shape is not None:
+            img = F.interpolate(img, desired_shape, mode="trilinear")
+        return img.cpu().float()
+
+    def __len__(self):
+        return self.data_num
+
+    def __getitem__(self, idx):
+        index1 = random.choice(self.image_ids)
+        image1 = self.images[index1][random.choice(self.modalities)]
+
+        index2 = random.choice(self.image_ids)
+        image2 = self.images[index2][random.choice(self.modalities)]
+
+        if not self.return_labels:
+            return blosc.unpack_array(image1), blosc.unpack_array(image2)
+
+        if self.randomization == 'random':
+            label1 = self.images[index1][random.choice(self.modalities)]
+            label2 = self.images[index2][random.choice(self.modalities)]
+        else:
+            modality = random.choice(self.modalities)
+            label1 = self.images[index1][modality]
+            label2 = self.images[index2][modality]
+
+        return blosc.unpack_array(image1), blosc.unpack_array(image2), blosc.unpack_array(label1), blosc.unpack_array(label2)
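+
+# NOTE (added comment): these datasets keep preprocessed volumes blosc-compressed
+# in RAM (pack_array at load time, unpack_array per __getitem__), trading a small
+# per-sample decompression cost for a much smaller resident memory footprint.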
+
+class ABCDDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        phase="train",
+        data_path=f"{DATASET_DIR}",
+        data_num=1000,
+        desired_shape=None,
+        device="cpu",
+        return_labels=False,
+    ):
+
+        md_path = f'{data_path}/dti_scalars/md/'
+        md_files = sorted(glob.glob(md_path + '*.nii.gz'))
+
+        fa_path = f'{data_path}/dti_scalars/fa/'
+        fa_files = sorted(glob.glob(fa_path + '*.nii.gz'))
+
+        mri_path = f'{data_path}/structural_mri/'
+        mri_files = sorted(glob.glob(mri_path + '*_oriented_stripped.nii.gz'))
+
+        if phase == "train":
+            fa_files = fa_files[10:]
+            md_files = md_files[10:]
+            mri_files = mri_files[10:]
+        elif phase == "val":
+            fa_files = fa_files[:10]
+            md_files = md_files[:10]
+            mri_files = mri_files[:10]
+
+        fa_ids = {x.split('/')[-1].split('.')[0].split('_')[0].split('-')[1] : x for x in fa_files}
+        md_ids = {x.split('/')[-1].split('.')[0].split('_')[0].split('-')[1] : x for x in md_files}
+        mri_ids = {x.split('/')[-1].split('.')[0].split('_')[0].split('-')[1] : x for x in mri_files}
+
+        self.desired_shape = desired_shape
+        self.device = device
+        self.return_labels = return_labels
+        self.images = {'FA': [], 'MD': [], 'T1': [], 'T2': []}
+        self.anatomies = ['brain']
+        self.region_num = 1
+
+        for id in tqdm(fa_ids):
+            fa = itk.imread(fa_ids[id])
+            self.images['FA'].append(self.pack_and_process_image(fa))
+            md = itk.imread(md_ids[id])
+            self.images['MD'].append(self.pack_and_process_image(md))
+
+        for mri_id in tqdm(mri_ids):
+            mri = mri_ids[mri_id].replace('T1w', 'modality').replace('T2w', 'modality')
+            if not os.path.exists(mri.replace('modality', 'T1w')) or not os.path.exists(mri.replace('modality', 'T2w')):
+                continue
+
+            t1_mri = itk.imread(mri.replace('modality', 'T1w'))
+            self.images['T1'].append(self.pack_and_process_image(t1_mri))
+
+            t2_mri = itk.imread(mri.replace('modality', 'T2w'))
+            self.images['T2'].append(self.pack_and_process_image(t2_mri))
+
+        self.data_num = data_num
+        self.modalities = list(self.images.keys())
+
+    def pack_and_process_image(self, image):
+        # process() once; the earlier double call re-normalized already-normalized data
+        processed_image = self.process(torch.tensor(np.asarray(image))[None, None], self.desired_shape, self.device)[0]
+        array_image = processed_image.numpy()
+        return blosc.pack_array(array_image)
+
+    def process(self, img, desired_shape=None, device="cpu"):
+        img = img.to(device).float()
+        im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99)
+        img = torch.clip(img, im_min, im_max)
+        img = (img-im_min) / (im_max-im_min)
+        if desired_shape is not None:
+            img = F.interpolate(img, desired_shape, mode="trilinear")
+        return img.cpu().float()
+
+    def __len__(self):
+        return self.data_num
+
+    def __getitem__(self, idx):
+        image1 = np.random.choice(self.images[np.random.choice(self.modalities)])
+        image2 = np.random.choice(self.images[np.random.choice(self.modalities)])
+
+        if not self.return_labels:
+            return blosc.unpack_array(image1), blosc.unpack_array(image2)
+        else:
+            return blosc.unpack_array(image1), blosc.unpack_array(image2), blosc.unpack_array(image1), blosc.unpack_array(image2)
+
+
+class OAIMMDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        data_path=f"{DATASET_DIR}/oai",
+        data_num=1000,
+        desired_shape=None,
+        device="cpu",
+        return_labels=False,
+    ):
+        self.desired_shape = desired_shape
+        self.device = device
+        self.return_labels = return_labels
+        self.anatomies = ['knee']
+        self.region_num = 1
+
+        dataset_dess = torch.load(f"{data_path}/dess_images.pt", map_location="cpu")
+        dataset_T2 = torch.load(f"{data_path}/t2_images.pt", map_location="cpu")
+
+        self.images = {'DESS': [], 'T2': []}
+
+        self.images['DESS'] = list(map(lambda x: self.pack_and_process_image(x), tqdm(dataset_dess)))
+        self.images['T2'] = list(map(lambda x: self.pack_and_process_image(x), tqdm(dataset_T2)))
+
+        self.data_num = data_num
+        self.modalities = list(self.images.keys())
+
+    def pack_and_process_image(self, image):
+        processed_image = self.process(image, self.desired_shape, self.device)[0]
+        array_image = processed_image.numpy()
+        return blosc.pack_array(array_image)
+
+    def process(self, img, desired_shape=None, device="cpu"):
+        img = img.to(device).float()
+        im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99)
+        img = torch.clip(img, im_min, im_max)
+        img = (img-im_min) / (im_max-im_min)
+        if desired_shape is not None:
+            img = F.interpolate(img, desired_shape, mode="trilinear")
+        return img.cpu()
+
+    def __len__(self):
+        return self.data_num
+
+    def __getitem__(self, idx):
+        modality_1, modality_2 = random.choices(self.modalities, k=2)
+        img_a = random.choice(self.images[modality_1])
+        img_b = random.choice(self.images[modality_2])
+
+        if self.return_labels:
+            return blosc.unpack_array(img_a), blosc.unpack_array(img_b), blosc.unpack_array(img_a), blosc.unpack_array(img_b)
+        else:
+            return blosc.unpack_array(img_a), blosc.unpack_array(img_b)
+
+class L2rMRCTDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        data_path=f"{DATASET_DIR}/AbdomenMRCT/",
+        data_num=1000,
+        desired_shape=None,
+        device="cpu",
+        phase = "train",
+        augmentation = True,
+        return_labels=False
+    ):
+        #inter-patient
+        self.phase = phase
+        self.device = device
+        self.return_labels = return_labels
+        self.anatomies = ['abdomen']
+        self.region_num = 1
+
+        with open(f"{data_path}/AbdomenMRCT_dataset.json", 'r') as data_info:
+            data_info = json.loads(data_info.read())
+
+        if self.phase == "train":
+            mr_samples = [c["image"] for c in data_info["training"]["0"]] #mr
+            ct_samples = [c["image"] for c in data_info["training"]["1"]] #ct
+        else:
+            mr_samples = [c["fixed"] for c in data_info["registration_test"]]
+            ct_samples = [c["moving"] for c in data_info["registration_test"]]
+
+        self.images = {'mr' : [], 'ct' : []}
+
+        if augmentation:
+            self.images['1-ct'] 
= [] + + for path in tqdm(mr_samples): + image = np.asarray(itk.imread(os.path.join(data_path, path))) + image = torch.Tensor(np.array(image)).unsqueeze(0).unsqueeze(0) + image = self.process_mr(image, desired_shape, device)[0] + self.images['mr'].append(self.pack(image)) + + for path in tqdm(ct_samples): + image = np.asarray(itk.imread(os.path.join(data_path, path))) + image = torch.Tensor(np.array(image)).unsqueeze(0).unsqueeze(0) + image = self.process_ct(image, desired_shape, device)[0] + self.images['ct'].append(self.pack(image)) + if augmentation: + self.images['1-ct'].append(self.pack(1-image)) + + self.data_num = data_num + self.modalities = list(self.images.keys()) + + def pack(self, image): + return blosc.pack_array(image.numpy()) + + def process_label(self, label, desired_shape=None, device="cpu"): + label = label.to(device) + if desired_shape is not None: + label = F.interpolate(label, desired_shape, mode="nearest") + return label.cpu() + + def process_ct(self, img, desired_shape=None, device="cpu"): + img = img.to(device) + img = (torch.clip(img.float(), -1000, 1000)+1000)/2000 + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def process_mr(self, img, desired_shape=None, device="cpu"): + img = img.to(device) + im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99) + img = torch.clip(img, im_min, im_max) + img = (img-im_min) / (im_max-im_min) + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + modality1 = random.choice(self.modalities) + modality2 = random.choice(self.modalities) + + idx1 = random.randint(0, len(self.images[modality1])-1) + idx2 = random.randint(0, len(self.images[modality2])-1) + + img_a = self.images[modality1][idx1] + img_b = self.images[modality2][idx2] + + if self.return_labels: + return blosc.unpack_array(img_a), blosc.unpack_array(img_b), blosc.unpack_array(img_a), blosc.unpack_array(img_b) + else: + return blosc.unpack_array(img_a), blosc.unpack_array(img_b) + + +class UKBiobankDataset(torch.utils.data.Dataset): + def __init__( + self, + data_path=f"{DATASET_DIR}/uk-biobank/", + data_num=1000, + desired_shape=None, + device="cpu", + phase = "train", + return_labels=False, + randomization = 'random' + ): + #contains 6 regions of the body - each region contains a list of images + fat_weighted_regions = torch.load(data_path + 'regions_fat.pt', map_location='cpu') + water_weighted_regions = torch.load(data_path + 'regions_water.pt', map_location='cpu') + + self.desired_shape = desired_shape + self.device = device + self.return_labels = return_labels + self.randomization = randomization + assert self.randomization in ['random', 'fixed'] + self.anatomies = ['abdomen', 'lung', 'knee'] + + self.images = {'fat': fat_weighted_regions, 'water': water_weighted_regions} + + for region in tqdm(range(6)): + if phase == "train": + self.images['fat'][region] = list(map(lambda x: self.pack_and_process_image(x), self.images['fat'][region][10:])) + self.images['water'][region] = list(map(lambda x: self.pack_and_process_image(x), self.images['water'][region][10:])) + else: + self.images['fat'][region] = list(map(lambda x: self.pack_and_process_image(x), self.images['fat'][region][:10])) + self.images['water'][region] = list(map(lambda x: self.pack_and_process_image(x), self.images['water'][region][:10])) + + self.data_num = data_num + self.modalities = ['fat', 
'water']
+        self.region_num = 6
+
+    def pack_and_process_image(self, image):
+        processed_image = self.process(image, self.desired_shape, self.device)[0]
+        array_image = processed_image.numpy()
+        return blosc.pack_array(array_image)
+
+    def process(self, img, desired_shape=None, device="cpu"):
+        img = torch.tensor(sitk.GetArrayFromImage(img)[None, None].astype(np.float32))
+        img = img.to(device)
+        im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99)
+        img = torch.clip(img, im_min, im_max)
+        img = (img-im_min) / (im_max-im_min)
+        if desired_shape is not None:
+            img = F.interpolate(img, desired_shape, mode="trilinear")
+        return img.cpu()
+
+    def __len__(self):
+        return self.data_num
+
+    def __getitem__(self, idx):
+        region = random.randint(0, self.region_num-1)
+        modality1 = random.choice(self.modalities)
+        modality2 = random.choice(self.modalities)
+
+        index1 = random.randint(0, len(self.images[modality1][region])-1)
+        index2 = random.randint(0, len(self.images[modality2][region])-1)
+
+        img_a = self.images[modality1][region][index1]
+        img_b = self.images[modality2][region][index2]
+
+        if not self.return_labels:
+            return blosc.unpack_array(img_a), blosc.unpack_array(img_b)
+
+        if self.randomization == 'random':
+            label1 = self.images[random.choice(self.modalities)][region][index1]
+            label2 = self.images[random.choice(self.modalities)][region][index2]
+        else:
+            modality = random.choice(self.modalities)
+            label1 = self.images[modality][region][index1]
+            label2 = self.images[modality][region][index2]
+
+        return blosc.unpack_array(img_a), blosc.unpack_array(img_b), blosc.unpack_array(label1), blosc.unpack_array(label2)
+
+
+class PancreasDataset(torch.utils.data.Dataset):
+    def __init__(self,
+        phase="train",
+        data_path=f"{DATASET_DIR}/pancreas/",
+        data_num=1000,
+        desired_shape=(175, 175, 175),
+        device="cpu",
+        return_labels=False,
+    ):
+
+        with open(f"{data_path}/{phase}_set.txt", 'r') as f:
+            self.img_list = [line.strip() for line in f]
+
+        self.desired_shape = desired_shape
+        self.data_num = data_num
+        self.img_dict = {}
+        self.return_labels = return_labels
+        self.modalities = ['ct', 'cbct']
+        self.region_num = 1
+        self.anatomies = ['abdomen']
+
+        for idx in range(len(self.img_list)):
+            ith_info = self.img_list[idx].split(" ")
+            ct_img_name = ith_info[0]
+            cb_img_name = ith_info[1]
+
+            ct_img_itk = sitk.ReadImage(ct_img_name)
+            cb_img_itk = sitk.ReadImage(cb_img_name)
+
+            ct_img_arr = sitk.GetArrayFromImage(ct_img_itk)
+            cb_img_arr = sitk.GetArrayFromImage(cb_img_itk)
+
+            ct_img_arr, cb_img_arr = self.process_training_data(ct_img_arr, cb_img_arr)
+
+            self.img_dict[ct_img_name] = blosc.pack_array(ct_img_arr)
+            self.img_dict[cb_img_name] = blosc.pack_array(cb_img_arr)
+
+    def __len__(self):
+        return self.data_num
+
+    def process(self, img, desired_shape=None, device="cpu"):
+        img = torch.tensor(img).unsqueeze(0).unsqueeze(0).cpu()
+        img = img.to(device).float()
+        img = (torch.clip(img.float(), -1000, 1000)+1000)/2000
+        if desired_shape is not None:
+            img = F.interpolate(img, desired_shape, mode="trilinear")
+        # move back to CPU before converting to a NumPy array for blosc packing
+        return img.squeeze().cpu().numpy()
+
+    def process_training_data(self, ct_img_arr, cb_img_arr):
+        ct_img_arr = self.process(ct_img_arr, self.desired_shape)
+        cb_img_arr = self.process(cb_img_arr, self.desired_shape)
+
+        return ct_img_arr, cb_img_arr
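+
+    # NOTE (added comment): each line of {phase}_set.txt lists one
+    # "ct_path cbct_path" pair; __getitem__ below ignores the sampler index
+    # and draws a random pair on every call.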
+
+    def __getitem__(self, idx):
+        idx = np.random.randint(0, len(self.img_list))
+        ith_info = self.img_list[idx].split(" ")
+        ct_img_name = ith_info[0]
+        cb_img_name = ith_info[1]
+
+        ct_img_arr = blosc.unpack_array(self.img_dict[ct_img_name])
+        cb_img_arr = blosc.unpack_array(self.img_dict[cb_img_name])
+
+        if not self.return_labels:
+            return ct_img_arr, cb_img_arr
+        else:
+            return ct_img_arr, cb_img_arr, ct_img_arr, cb_img_arr
+
+class L2rThoraxCBCTDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        data_path=f"{DATASET_DIR}/ThoraxCBCT",
+        data_num=1000,
+        desired_shape=None,
+        device="cpu",
+        return_labels=False,
+    ):
+
+        with open(f"{data_path}/ThoraxCBCT_dataset.json", 'r') as data_info:
+            data_info = json.loads(data_info.read())
+        cases = [f"{data_path}/{c['image']}" for c in data_info["training"]]
+
+        self.modalities = {'0000': [] , '0001': [], '0002': []}
+        self.desired_shape = desired_shape
+        self.device = device
+        self.return_labels = return_labels
+        self.anatomies = ['lung']
+        self.region_num = 1
+
+        for case in cases:
+            modality = case.split('/')[-1].split('_')[-1].split('.')[0]
+            self.modalities[modality].append(self.pack_and_process_image(itk.imread(case)))
+
+        self.data_num = data_num
+
+    def pack_and_process_image(self, image):
+        # apply the CT windowing exactly once; running process() twice would
+        # re-window already-normalized [0, 1] data and flatten its contrast
+        processed_image = self.process(torch.tensor(np.asarray(image))[None, None], self.desired_shape, self.device)[0]
+        array_image = processed_image.numpy()
+        return blosc.pack_array(array_image)
+
+    def process(self, img, desired_shape=None, device="cpu"):
+        img = img.to(device).float()
+        img = (torch.clip(img.float(), -1000, 1000)+1000)/2000
+        if desired_shape is not None:
+            img = F.interpolate(img, desired_shape, mode="trilinear")
+        return img.cpu().float()
+
+    def __len__(self):
+        return self.data_num
+
+    def __getitem__(self, idx):
+        modality1 = random.choice(list(self.modalities.keys()))
+        modality2 = random.choice(list(self.modalities.keys()))
+
+        patient_id = random.randint(0, len(self.modalities[modality1])-1)
+
+        image1 = self.modalities[modality1][patient_id]
+        image2 = self.modalities[modality2][patient_id]
+
+        if not self.return_labels:
+            return blosc.unpack_array(image1), blosc.unpack_array(image2)
+        else:
+            return blosc.unpack_array(image1), blosc.unpack_array(image2), blosc.unpack_array(image1), blosc.unpack_array(image2)
\ No newline at end of file
diff --git a/training/train.py b/training/train.py
new file mode 100644
index 0000000..80ec6b2
--- /dev/null
+++ b/training/train.py
@@ -0,0 +1,297 @@
+import os
+import random
+from datetime import datetime
+
+from tqdm import tqdm
+import torch
+import torch.nn.functional as F
+from dataset import COPDDataset, HCPDataset, OAIDataset, L2rAbdomenDataset
+from torch.utils.data import ConcatDataset, DataLoader
+
+import icon_registration as icon
+import icon_registration.networks as networks
+from icon_registration.losses import ICONLoss, to_floats
+
+from unigradicon import make_network
+
+def write_stats(writer, stats: ICONLoss, ite, prefix=""):
+    for k, v in to_floats(stats)._asdict().items():
+        writer.add_scalar(f"{prefix}{k}", v, ite)
+
+input_shape = [1, 1, 175, 175, 175]
+
+BATCH_SIZE = 4
+device_ids = [1, 0, 2, 3]
+GPUS = len(device_ids)
+EXP_DIR = "./results/unigradicon/"
+
+def get_dataset():
+    data_num = 1000
+    return ConcatDataset(
+        (
+            COPDDataset(desired_shape=input_shape[2:], device=device_ids[0], ROI_only=True, data_num=data_num),
+            OAIDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=data_num),
+            HCPDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=data_num),
+            L2rAbdomenDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=data_num)
+        )
+    )
+
+def augment(image_A, image_B):
+    device = 
image_A.device + identity_list = [] + for i in range(image_A.shape[0]): + identity = torch.tensor([[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], device=device) + idxs = set((0, 1, 2)) + for j in range(3): + k = random.choice(list(idxs)) + idxs.remove(k) + identity[0, j, k] = 1 + identity = identity * (torch.randint_like(identity, 0, 2, device=device) * 2 - 1) + identity_list.append(identity) + + identity = torch.cat(identity_list) + + noise = torch.randn((image_A.shape[0], 3, 4), device=device) + + forward = identity + .05 * noise + + grid_shape = list(image_A.shape) + grid_shape[1] = 3 + forward_grid = F.affine_grid(forward, grid_shape) + + if image_A.shape[1] > 1: + # Then we have segmentations + warped_A = F.grid_sample(image_A[:, :1], forward_grid, padding_mode='border') + warped_A_seg = F.grid_sample(image_A[:, 1:], forward_grid, mode='nearest', padding_mode='border') + warped_A = torch.cat([warped_A, warped_A_seg], axis=1) + else: + warped_A = F.grid_sample(image_A, forward_grid, padding_mode='border') + + noise = torch.randn((image_A.shape[0], 3, 4), device=device) + forward = identity + .05 * noise + + grid_shape = list(image_A.shape) + grid_shape[1] = 3 + forward_grid = F.affine_grid(forward, grid_shape) + + if image_B.shape[1] > 1: + # Then we have segmentations + warped_B = F.grid_sample(image_B[:, :1], forward_grid, padding_mode='border') + warped_B_seg = F.grid_sample(image_B[:, 1:], forward_grid, mode='nearest', padding_mode='border') + warped_B = torch.cat([warped_B, warped_B_seg], axis=1) + else: + warped_B = F.grid_sample(image_B, forward_grid, padding_mode='border') + + return warped_A, warped_B + +def train_kernel(optimizer, net, moving_image, fixed_image, writer, ite): + optimizer.zero_grad() + loss_object = net(moving_image, fixed_image) + loss = torch.mean(loss_object.all_loss) + loss.backward() + optimizer.step() + # print(to_floats(loss_object)) + write_stats(writer, loss_object, ite, prefix="train/") + +def train( + net, + optimizer, + data_loader, + val_data_loader, + epochs=200, + eval_period=-1, + save_period=-1, + step_callback=(lambda net: None), + unwrapped_net=None, + data_augmenter=None, +): + """A training function intended for long running experiments, with tensorboard logging + and model checkpoints. 
Use for medical registration training + """ + import footsteps + from torch.utils.tensorboard import SummaryWriter + + if unwrapped_net is None: + unwrapped_net = net + + loss_curve = [] + writer = SummaryWriter( + footsteps.output_dir + "/logs/" + datetime.now().strftime("%Y%m%d-%H%M%S"), + flush_secs=30, + ) + + iteration = 0 + for epoch in tqdm(range(epochs)): + for moving_image, fixed_image in data_loader: + moving_image, fixed_image = moving_image.cuda(), fixed_image.cuda() + if data_augmenter is not None: + with torch.no_grad(): + moving_image, fixed_image = data_augmenter(moving_image, fixed_image) + train_kernel(optimizer, net, moving_image, fixed_image, + writer, iteration) + iteration += 1 + + step_callback(unwrapped_net) + + + if epoch % save_period == 0: + torch.save( + optimizer.state_dict(), + footsteps.output_dir + "checkpoints/optimizer_weights_" + str(epoch), + ) + torch.save( + unwrapped_net.regis_net.state_dict(), + footsteps.output_dir + "checkpoints/network_weights_" + str(epoch), + ) + + if epoch % eval_period == 0: + visualization_moving, visualization_fixed = next(iter(val_data_loader)) + visualization_moving, visualization_fixed = visualization_moving[:, :1].cuda(), visualization_fixed[:, :1].cuda() + unwrapped_net.eval() + warped = [] + with torch.no_grad(): + eval_loss = unwrapped_net(visualization_moving, visualization_fixed) + write_stats(writer, eval_loss, epoch, prefix="val/") + warped = unwrapped_net.warped_image_A.cpu() + del eval_loss + unwrapped_net.clean() + unwrapped_net.train() + + def render(im): + if len(im.shape) == 5: + im = im[:, :, :, im.shape[3] // 2] + if torch.min(im) < 0: + im = im - torch.min(im) + if torch.max(im) > 1: + im = im / torch.max(im) + return im[:4, [0, 0, 0]].detach().cpu() + + writer.add_images( + "moving_image", render(visualization_moving[:4]), epoch, dataformats="NCHW" + ) + writer.add_images( + "fixed_image", render(visualization_fixed[:4]), epoch, dataformats="NCHW" + ) + writer.add_images( + "warped_moving_image", + render(warped), + epoch, + dataformats="NCHW", + ) + writer.add_images( + "difference", + render(torch.clip((warped[:4, :1] - visualization_fixed[:4, :1].cpu()) + 0.5, 0, 1)), + epoch, + dataformats="NCHW", + ) + + torch.save( + optimizer.state_dict(), + footsteps.output_dir + "checkpoints/optimizer_weights_" + str(epoch), + ) + torch.save( + unwrapped_net.regis_net.state_dict(), + footsteps.output_dir + "checkpoints/network_weights_" + str(epoch), + ) + +def train_two_stage(input_shape, data_loader, val_data_loader, GPUS, epochs, eval_period, save_period, resume_from): + + net = make_network(input_shape, include_last_step=False) + + torch.cuda.set_device(device_ids[0]) + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + # device = f"cuda:{device_ids[0]}" + + # Continue train + if resume_from != "": + print("Resume from: ", resume_from) + net.regis_net.load_state_dict(torch.load(resume_from, map_location="cpu")) + + if GPUS == 1: + net_par = net.cuda() + else: + net_par = torch.nn.DataParallel(net, device_ids=device_ids, output_device=device_ids[0]).cuda() + optimizer = torch.optim.Adam(net_par.parameters(), lr=0.00005) + + if resume_from != "": + optimizer.load_state_dict(torch.load(resume_from.replace("network_weights_", "optimizer_weights_"), map_location="cpu")) + + net_par.train() + + print("start train.") + train(net_par, optimizer, data_loader, val_data_loader, unwrapped_net=net, + epochs=epochs[0], eval_period=eval_period, save_period=save_period, data_augmenter=augment) + + 
torch.save( + net.regis_net.state_dict(), + footsteps.output_dir + "checkpoints/Step_1_final.trch", + ) + + net_2 = make_network(input_shape, include_last_step=True) + + net_2.regis_net.netPhi.load_state_dict(net.regis_net.state_dict()) + + # Continue train + # if resume_from != "": + # print("Resume from: ", resume_from) + # net_2.regis_net.load_state_dict(torch.load(resume_from, map_location="cpu")) + + del net + del net_par + del optimizer + + if GPUS == 1: + net_2_par = net_2.cuda() + else: + net_2_par = torch.nn.DataParallel(net_2, device_ids=device_ids, output_device=device_ids[0]).cuda() + optimizer = torch.optim.Adam(net_2_par.parameters(), lr=0.00005) + + # if resume_from != "": + # optimizer.load_state_dict(torch.load(resume_from.replace("network_weights_", "optimizer_weights_"), map_location="cpu")) + + net_2_par.train() + + # We're being weird by training two networks in one script. This hack keeps + # the second training from overwriting the outputs of the first. + footsteps.output_dir_impl = footsteps.output_dir + "2nd_step/" + os.makedirs(footsteps.output_dir + "checkpoints", exist_ok=True) + + train(net_2_par, optimizer, data_loader, val_data_loader, unwrapped_net=net_2, epochs=epochs[1], eval_period=eval_period, save_period=save_period, data_augmenter=augment) + + torch.save( + net_2.regis_net.state_dict(), + footsteps.output_dir + "checkpoints/Step_2_final.trch", + ) + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--resume_from", required=False, default="") + args = parser.parse_args() + resume_from = args.resume_from + + import footsteps + footsteps.initialize(output_root=EXP_DIR) + + dataset = get_dataset() + dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE*GPUS, + shuffle=True, + num_workers=4, + drop_last=True, + ) + val_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + num_workers=4, + drop_last=True, + ) + print("Finish data loading...") + + os.makedirs(footsteps.output_dir + "checkpoints", exist_ok=True) + + train_two_stage(input_shape, dataloader, val_dataloader, GPUS, [801,201], 20, 20, resume_from) \ No newline at end of file diff --git a/training/train_multi.py b/training/train_multi.py new file mode 100644 index 0000000..01de562 --- /dev/null +++ b/training/train_multi.py @@ -0,0 +1,396 @@ +import os +import random +from datetime import datetime + +from tqdm import tqdm +import torch +import torch.nn.functional as F +import dataset_multi +from torch.utils.data import ConcatDataset, DataLoader +from icon_registration.losses import ICONLoss, to_floats +from unigradicon import make_network +import math + +def write_stats(writer, stats: ICONLoss, ite, prefix=""): + for k, v in to_floats(stats)._asdict().items(): + writer.add_scalar(f"{prefix}{k}", v, ite) + +input_shape = [1, 1, 175, 175, 175] +DATA_NUM = 4000 +BATCH_SIZE= 4 +device_ids = [1, 0, 2, 3] +GPUS = len(device_ids) +EXP_DIR = "./results/multigradicon/" + +def get_multi_training_set(): + randomization = 'random' + + datasets = [ + dataset_multi.COPDDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, ROI_only=False), + dataset_multi.BratsRegDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization), + dataset_multi.L2rAbdomenDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization, augmentation = True), + 
+        dataset_multi.HCPDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization),
+        dataset_multi.ABCDFAMDDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization),
+        dataset_multi.OAIMMDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True),
+        dataset_multi.L2rMRCTDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, augmentation=True),
+        dataset_multi.UKBiobankDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization),
+    ]
+
+    dataset = ConcatDataset(datasets)
+
+    # Modality-based weighting: weight each dataset by its number of anatomical
+    # regions times its number of unordered modality pairs (with repetition),
+    # one weight per sample. Weights are flattened into a single list so they
+    # can be passed directly to WeightedRandomSampler, and the loop variable
+    # must not reuse the name `dataset`, which already holds the ConcatDataset.
+    dataset_weights = []
+    for ds in datasets:
+        weight = ds.region_num * math.comb(len(ds.modalities) + 1, 2)
+        dataset_weights.extend([weight] * len(ds))
+
+    return dataset, dataset_weights
+
+def get_multi_finetuning_set():
+    randomization = 'random'
+
+    datasets = [
+        dataset_multi.PancreasDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True),
+        dataset_multi.L2rThoraxCBCTDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True),
+        dataset_multi.ABCDDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True),
+        dataset_multi.COPDDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, ROI_only=True),
+        dataset_multi.BratsRegDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization),
+        dataset_multi.L2rAbdomenDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization, augmentation=False),
+        dataset_multi.HCPDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization),
+        dataset_multi.OAIMMDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True),
+        dataset_multi.L2rMRCTDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, augmentation=False),
+        dataset_multi.UKBiobankDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization),
+    ]
+
+    dataset = ConcatDataset(datasets)
+
+    # Anatomy-based weighting: count how many datasets cover each anatomical
+    # region, then upweight datasets whose regions appear in few datasets so
+    # that all regions are sampled at comparable rates.
+    anatomies_count = {}
+    for ds in datasets:
+        for anatomy in ds.anatomies:
+            if anatomy not in anatomies_count:
+                anatomies_count[anatomy] = 0
+            anatomies_count[anatomy] += 1
+
+    max_anatomy = max(anatomies_count.values())
+
+    # Again one weight per sample, flattened for WeightedRandomSampler.
+    dataset_weights = []
+    for ds in datasets:
+        weight = 0
+        for anatomy in ds.anatomies:
+            weight += max_anatomy / anatomies_count[anatomy]
+        dataset_weights.extend([weight] * len(ds))
+
+    return dataset, dataset_weights
+
+def augment(image_A, image_B, label_A, label_B):
+    device = image_A.device
+    identity_list = []
+    for _ in range(image_A.shape[0]):
+        identity = torch.tensor([[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], device=device)
+        idxs = set((0, 1, 2))
+        for j in range(3):
+            k = random.choice(list(idxs))
+            idxs.remove(k)
+            identity[0, j, k] = 1
+        identity = identity * (torch.randint_like(identity, 0, 2, device=device) * 2 - 1)
+        identity_list.append(identity)
+
+    identity = torch.cat(identity_list)
+
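+    # At this point each batch entry of `identity` is a 3x4 affine matrix: its
+    # 3x3 part is a signed permutation (a random axis reordering with random
+    # per-axis flips) and its translation column is zero. The Gaussian noise
+    # added next turns it into a small random affine perturbation.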
noise = torch.randn((image_A.shape[0], 3, 4), device=device) + + forward = identity + .05 * noise + + grid_shape = list(image_A.shape) + grid_shape[1] = 3 + forward_grid = F.affine_grid(forward, grid_shape) + + if image_A.shape[1] > 1: + # Then we have segmentations + warped_A = F.grid_sample(image_A[:, :1], forward_grid, padding_mode='border') + warped_A_seg = F.grid_sample(image_A[:, 1:], forward_grid, mode='nearest', padding_mode='border') + warped_A = torch.cat([warped_A, warped_A_seg], axis=1) + else: + warped_A = F.grid_sample(image_A, forward_grid, padding_mode='border') + warped_label_A = F.grid_sample(label_A, forward_grid, padding_mode='border') + + noise = torch.randn((image_A.shape[0], 3, 4), device=device) + forward = identity + .05 * noise + + grid_shape = list(image_A.shape) + grid_shape[1] = 3 + forward_grid = F.affine_grid(forward, grid_shape) + + if image_B.shape[1] > 1: + # Then we have segmentations + warped_B = F.grid_sample(image_B[:, :1], forward_grid, padding_mode='border') + warped_B_seg = F.grid_sample(image_B[:, 1:], forward_grid, mode='nearest', padding_mode='border') + warped_B = torch.cat([warped_B, warped_B_seg], axis=1) + else: + warped_B = F.grid_sample(image_B, forward_grid, padding_mode='border') + warped_label_B = F.grid_sample(label_B, forward_grid, padding_mode='border') + + return warped_A, warped_B, warped_label_A, warped_label_B + +def train_kernel(optimizer, net, moving_image, fixed_image, moving_label, fixed_label, writer, ite): + optimizer.zero_grad() + loss_object = net(moving_image, fixed_image, moving_label, fixed_label) + loss = torch.mean(loss_object.all_loss) + loss.backward() + optimizer.step() + # print(to_floats(loss_object)) + write_stats(writer, loss_object, ite, prefix="train/") + +def train( + net, + optimizer, + data_loader, + val_data_loader, + epochs=200, + eval_period=-1, + save_period=-1, + step_callback=(lambda net: None), + unwrapped_net=None, + data_augmenter=None, +): + """A training function intended for long running experiments, with tensorboard logging + and model checkpoints. 
Use for medical registration training + """ + import footsteps + from torch.utils.tensorboard import SummaryWriter + + if unwrapped_net is None: + unwrapped_net = net + + loss_curve = [] + writer = SummaryWriter( + footsteps.output_dir + "/logs/" + datetime.now().strftime("%Y%m%d-%H%M%S"), + flush_secs=30, + ) + + iteration = 0 + for epoch in tqdm(range(epochs)): + for moving_image, fixed_image, moving_label, fixed_label in data_loader: + moving_image, fixed_image, moving_label, fixed_label = moving_image.cuda(), fixed_image.cuda(), moving_label.cuda(), fixed_label.cuda() + if data_augmenter is not None: + with torch.no_grad(): + moving_image, fixed_image, moving_label, fixed_label = data_augmenter(moving_image, fixed_image, moving_label, fixed_label) + train_kernel(optimizer, net, moving_image, fixed_image, moving_label, fixed_label, writer, iteration) + iteration += 1 + + step_callback(unwrapped_net) + + + if epoch % save_period == 0: + torch.save( + optimizer.state_dict(), + footsteps.output_dir + "checkpoints/optimizer_weights_" + str(epoch), + ) + torch.save( + unwrapped_net.regis_net.state_dict(), + footsteps.output_dir + "checkpoints/network_weights_" + str(epoch), + ) + + if epoch % eval_period == 0: + visualization_moving, visualization_fixed, _, _ = next(iter(val_data_loader)) + visualization_moving, visualization_fixed = visualization_moving[:, :1].cuda(), visualization_fixed[:, :1].cuda() + unwrapped_net.eval() + warped = [] + with torch.no_grad(): + eval_loss = unwrapped_net(visualization_moving, visualization_fixed) + write_stats(writer, eval_loss, epoch, prefix="val/") + warped = unwrapped_net.warped_image_A.cpu() + del eval_loss + unwrapped_net.clean() + unwrapped_net.train() + + def render(im): + if len(im.shape) == 5: + im = im[:, :, :, im.shape[3] // 2] + if torch.min(im) < 0: + im = im - torch.min(im) + if torch.max(im) > 1: + im = im / torch.max(im) + return im[:4, [0, 0, 0]].detach().cpu() + + writer.add_images( + "moving_image", render(visualization_moving[:4]), epoch, dataformats="NCHW" + ) + writer.add_images( + "fixed_image", render(visualization_fixed[:4]), epoch, dataformats="NCHW" + ) + writer.add_images( + "warped_moving_image", + render(warped), + epoch, + dataformats="NCHW", + ) + writer.add_images( + "difference", + render(torch.clip((warped[:4, :1] - visualization_fixed[:4, :1].cpu()) + 0.5, 0, 1)), + epoch, + dataformats="NCHW", + ) + + torch.save( + optimizer.state_dict(), + footsteps.output_dir + "checkpoints/optimizer_weights_" + str(epoch), + ) + torch.save( + unwrapped_net.regis_net.state_dict(), + footsteps.output_dir + "checkpoints/network_weights_" + str(epoch), + ) + +def train_two_stage(input_shape, data_loader, val_data_loader, GPUS, epochs, eval_period, save_period, resume_from): + + net = make_network(input_shape, include_last_step=False, use_label=True) + + torch.cuda.set_device(device_ids[0]) + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + # device = f"cuda:{device_ids[0]}" + + # Continue train + if resume_from != "": + print("Resume from: ", resume_from) + net.regis_net.load_state_dict(torch.load(resume_from, map_location="cpu")) + + if GPUS == 1: + net_par = net.cuda() + else: + net_par = torch.nn.DataParallel(net, device_ids=device_ids, output_device=device_ids[0]).cuda() + optimizer = torch.optim.Adam(net_par.parameters(), lr=0.00005) + + if resume_from != "": + optimizer.load_state_dict(torch.load(resume_from.replace("network_weights_", "optimizer_weights_"), map_location="cpu")) + + net_par.train() + + 
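+    # Note: unlike the unimodal training script, the multimodal data loader
+    # yields label maps alongside the images, and train_kernel forwards them to
+    # the label-aware network created with make_network(..., use_label=True).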
print("start train.") + train(net_par, optimizer, data_loader, val_data_loader, unwrapped_net=net, + epochs=epochs[0], eval_period=eval_period, save_period=save_period, data_augmenter=augment) + + torch.save( + net.regis_net.state_dict(), + footsteps.output_dir + "checkpoints/Step_1_final.trch", + ) + + net_2 = make_network(input_shape, include_last_step=True, use_label=True) + + net_2.regis_net.netPhi.load_state_dict(net.regis_net.state_dict()) + + # Continue train + # if resume_from != "": + # print("Resume from: ", resume_from) + # net_2.regis_net.load_state_dict(torch.load(resume_from, map_location="cpu")) + + del net + del net_par + del optimizer + + if GPUS == 1: + net_2_par = net_2.cuda() + else: + net_2_par = torch.nn.DataParallel(net_2, device_ids=device_ids, output_device=device_ids[0]).cuda() + optimizer = torch.optim.Adam(net_2_par.parameters(), lr=0.00005) + + # if resume_from != "": + # optimizer.load_state_dict(torch.load(resume_from.replace("network_weights_", "optimizer_weights_"), map_location="cpu")) + + net_2_par.train() + + # We're being weird by training two networks in one script. This hack keeps + # the second training from overwriting the outputs of the first. + footsteps.output_dir_impl = footsteps.output_dir + "2nd_step/" + os.makedirs(footsteps.output_dir + "checkpoints", exist_ok=True) + + train(net_2_par, optimizer, data_loader, val_data_loader, unwrapped_net=net_2, epochs=epochs[1], eval_period=eval_period, save_period=save_period, data_augmenter=augment) + + torch.save( + net_2.regis_net.state_dict(), + footsteps.output_dir + "checkpoints/Step_2_final.trch", + ) + + return net_2 + +def finetune(net, data_loader, val_data_loader, GPUS, epochs, eval_period, save_period): + if GPUS == 1: + net_par = net.cuda() + else: + net_par = torch.nn.DataParallel(net, device_ids=device_ids, output_device=device_ids[0]).cuda() + optimizer = torch.optim.Adam(net_par.parameters(), lr=0.00005) + + net_par.train() + + footsteps.output_dir_impl = footsteps.output_dir.split("2nd_step/")[0] + "finetune/" + os.makedirs(footsteps.output_dir + "checkpoints", exist_ok=True) + + train(net_par, optimizer, data_loader, val_data_loader, unwrapped_net=net, epochs=epochs, eval_period=eval_period, save_period=save_period, data_augmenter=augment) + + torch.save( + net.regis_net.state_dict(), + footsteps.output_dir + "checkpoints/finetune_final.trch", + ) + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--resume_from", required=False, default="") + args = parser.parse_args() + resume_from = args.resume_from + + import footsteps + footsteps.initialize(output_root=EXP_DIR) + + dataset, weights = get_multi_training_set() + + dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE*GPUS, + num_workers=4, + drop_last=True, + sampler=torch.utils.data.WeightedRandomSampler(weights, DATA_NUM) + ) + + val_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + num_workers=4, + drop_last=True, + ) + print("Finish data loading...") + + os.makedirs(footsteps.output_dir + "checkpoints", exist_ok=True) + + print("Start training...") + net = train_two_stage(input_shape, dataloader, val_dataloader, GPUS, [801,201], 20, 20, resume_from) + + del dataloader, val_dataloader, dataset, weights + + print("Start finetuning...") + + fine_dataset, fine_weights = get_multi_finetuning_set() + + fine_dataloader = DataLoader( + fine_dataset, + batch_size=BATCH_SIZE*GPUS, + num_workers=4, + drop_last=True, + 
+        sampler=torch.utils.data.WeightedRandomSampler(fine_weights, DATA_NUM)
+    )
+
+    fine_val_dataloader = DataLoader(
+        fine_dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=True,
+        num_workers=4,
+        drop_last=True,
+    )
+
+    print("Finish data loading...")
+
+    finetune(net, fine_dataloader, fine_val_dataloader, GPUS, 100, 20, 20)
\ No newline at end of file