From 0892f461febce439689884aad86681964a0e43fb Mon Sep 17 00:00:00 2001 From: Lin Tian Date: Thu, 22 Aug 2024 09:19:01 -0400 Subject: [PATCH 1/4] Add the training scripts. --- training/dataset.py | 291 ++++++++++++++++++++++++++++ training/train.py | 458 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 749 insertions(+) create mode 100644 training/dataset.py create mode 100644 training/train.py diff --git a/training/dataset.py b/training/dataset.py new file mode 100644 index 0000000..ff948a0 --- /dev/null +++ b/training/dataset.py @@ -0,0 +1,291 @@ +import torch +import os +import torch.nn.functional as F +import random +import numpy as np +import itk +import glob +import SimpleITK as sitk +from tqdm import tqdm + +DATASET_DIR = "./data/uniGradICON/" + +class COPDDataset(torch.utils.data.Dataset): + def __init__( + self, + phase="train", + scale="2xdown", + data_path=f"{DATASET_DIR}/half_res_preprocessed_transposed_SI", + ROI_only=False, + data_num=-1, + desired_shape=None, + device="cpu" + ): + if phase == "debug": + phase = "train" + self.imgs = torch.load( + f"{data_path}/lungs_{phase}_{scale}_scaled", map_location="cpu" + ) + if data_num <= 0 or data_num > len(self.imgs): + self.data_num = len(self.imgs) + else: + self.data_num = data_num + self.imgs = self.imgs[: self.data_num] + + # Normalize to [0,1] + print("Processing COPD data.") + if ROI_only: + segs = torch.load( + f"{data_path}/lungs_seg_{phase}_{scale}_scaled", map_location="cpu" + )[: self.data_num] + self.imgs = list(map(lambda x: (self.process(x[0][0], desired_shape, device, x[1][0])[0],self.process(x[0][1], desired_shape, device, x[1][1])[0]), tqdm(zip(self.imgs, segs)))) + else: + self.imgs = list(map(lambda x: (self.process(x[0], desired_shape, device)[0],self.process(x[1], desired_shape, device)[0]), tqdm(self.imgs))) + + + def process(self, img, desired_shape=None, device="cpu", seg=None): + img = img.to(device) + im_min, im_max = torch.min(img), torch.max(img) + img = (img-im_min) / (im_max-im_min) + if seg is not None: + seg = seg.to(device) + img = (img * seg).float() + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + img_a, img_b = self.imgs[idx] + return img_a, img_b + + +class OAIDataset(torch.utils.data.Dataset): + def __init__( + self, + phase="train", + scale="2xdownsample", + data_path=f"{DATASET_DIR}/OAI", + data_num=1000, + desired_shape=None, + device="cpu" + ): + if phase == "debug": + phase = "train" + if phase == "test": + phase = "train" + print( + "WARNING: There is no validation set for OAI. Using train data for test set." 
+ ) + self.imgs = torch.load( + f"{data_path}/knees_big_{scale}_train_set", map_location="cpu" + ) + + print("Processing OAI data.") + self.imgs = list(map(lambda x: self.process(x, desired_shape, device)[0], tqdm(self.imgs))) + + self.img_num = len(self.imgs) + + self.data_num = data_num + + def process(self, img, desired_shape=None, device="cpu"): + img = img.to(device) + im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99) + img = torch.clip(img, im_min, im_max) + img = (img-im_min) / (im_max-im_min) + + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + idx_a = random.randint(0, self.img_num - 1) + idx_b = random.randint(0, self.img_num - 1) + img_a = self.imgs[idx_a] + img_b = self.imgs[idx_b] + return img_a, img_b + + +class HCPDataset(torch.utils.data.Dataset): + def __init__( + self, + phase="train", + scale="2xdown", + data_path=f"{DATASET_DIR}/HCP", + data_num=1000, + desired_shape=None, + device="cpu" + ): + if phase == "debug": + phase = "train" + if phase == "test": + phase = "train" + print( + "WARNING: There is no validation set for OAI. Using train data for test set." + ) + self.imgs = torch.load( + f"{data_path}/brain_train_{scale}_scaled", map_location="cpu" + ) + print("Processing HCP data.") + self.imgs = list(map(lambda x: self.process(x, desired_shape, device)[0], tqdm(self.imgs))) + + self.img_num = len(self.imgs) + + self.data_num = data_num + + def process(self, img, desired_shape=None, device="cpu"): + img = img.to(device) + im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99) + img = torch.clip(img, im_min, im_max) + img = (img-im_min) / (im_max-im_min) + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + idx_a = random.randint(0, self.img_num - 1) + idx_b = random.randint(0, self.img_num - 1) + img_a = self.imgs[idx_a] + img_b = self.imgs[idx_b] + return img_a, img_b + + +class L2rAbdomenDataset(torch.utils.data.Dataset): + def __init__( + self, + data_path=f"{DATASET_DIR}/AbdomenCTCT", + data_num=1000, + desired_shape=None, + device="cpu" + ): + cases = list(map(lambda x: os.path.join(f"{data_path}/imagesTr", x), os.listdir(f"{data_path}/imagesTr/"))) + self.imgs = [] + print("Processing L2R Abdomen data.") + for i in tqdm(range(len(cases))): + case_path = cases[i] + self.imgs.append(self.process(torch.tensor(np.asarray(itk.imread(case_path)))[None, None], desired_shape, device)[0]) + + self.img_num = len(self.imgs) + + self.data_num = data_num + + def process(self, img, desired_shape=None, device="cpu"): + img = img.to(device) + img = (torch.clip(img.float(), -1000, 1000)+1000)/2000 + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + idx_a = random.randint(0, self.img_num - 1) + idx_b = random.randint(0, self.img_num - 1) + img_a = self.imgs[idx_a] + img_b = self.imgs[idx_b] + return img_a, img_b + + +class L2rThoraxCBCTDataset(torch.utils.data.Dataset): + def __init__( + self, + data_path=f"{DATASET_DIR}/ThoraxCBCT", + data_num=1000, + desired_shape=None, + device="cpu" + ): + import json + with open(f"{data_path}/ThoraxCBCT_dataset.json", 'r') as data_info: + data_info = json.loads(data_info.read()) + cases = 
[[f"{data_path}/{c['moving']}", f"{data_path}/{c['fixed']}"] for c in data_info["training_paired_images"]] + self.imgs = [] + print("Processing L2R ThoraxCBCT data.") + for i in tqdm(range(len(cases))): + moving_path, fixed_path = cases[i] + self.imgs.append( + ( + self.process(torch.tensor(np.asarray(itk.imread(moving_path)))[None, None], desired_shape, device)[0], + self.process(torch.tensor(np.asarray(itk.imread(fixed_path)))[None, None], desired_shape, device)[0] + )) + + if data_num < len(self.imgs): + self.imgs = self.imgs[:data_num] + self.data_num = len(self.imgs) + + def process(self, img, desired_shape=None, device="cpu"): + img = img.to(device) + img = (torch.clip(img.float(), -1000, 1000)+1000)/2000 + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + img_a, img_b = self.imgs[idx] + return img_a, img_b + + +class ACDCDataset(torch.utils.data.Dataset): + def __init__( + self, + data_path=f"{DATASET_DIR}/ACDC", + desired_shape=None + ): + self.imgs = [] + training_files = sorted(glob(f"{data_path}/database/training/*/*4d.nii.gz")) + + for file in training_files: + self.imgs.append( + self.process( + torch.tensor( + sitk.GetArrayFromImage(sitk.ReadImage(file, sitk.sitkFloat32)) + )[None], + desired_shape) + ) + + self.img_num = len(self.imgs) + + + def process(self, img, desired_shape=None): + img = img / torch.amax(img, dim=(1,2,3), keepdim=True) + + # Pad the image + pad_size = (img.shape[2] - img.shape[1]) // 2 + img = torch.pad(img, (0, 0, 0, 0, pad_size, pad_size), "constant", 0) + + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img + + def __len__(self): + return self.img_num + + def __getitem__(self, idx): + img = self.imgs[idx] + img_count = img.shape[0] + idx_a = random.randint(0, img_count - 1) + idx_b = random.randint(0, img_count - 1) + img_a = img[idx_a] + img_b = img[idx_b] + return img_a, img_b + + +if __name__ == "__main__": + from torch.utils.data import DataLoader + datasets = [COPDDataset, OAIDataset, HCPDataset, L2rAbdomenDataset] + + for dataset in datasets: + data = dataset(desired_shape=(64, 64, 64)) + data = DataLoader(data, batch_size=3) + print(next(iter(data))[0].shape) diff --git a/training/train.py b/training/train.py new file mode 100644 index 0000000..9595ba8 --- /dev/null +++ b/training/train.py @@ -0,0 +1,458 @@ +import os +import random +from datetime import datetime + +from tqdm import tqdm +import numpy as np +import torch +import torch.nn.functional as F +from dataset import COPDDataset, HCPDataset, OAIDataset, L2rAbdomenDataset +from torch.utils.data import ConcatDataset, DataLoader + +import icon_registration as icon +import icon_registration.network_wrappers as network_wrappers +import icon_registration.networks as networks +from icon_registration import config +from icon_registration.losses import ICONLoss, to_floats +from icon_registration.mermaidlite import compute_warped_image_multiNC + + +def write_stats(writer, stats: ICONLoss, ite, prefix=""): + for k, v in to_floats(stats)._asdict().items(): + writer.add_scalar(f"{prefix}{k}", v, ite) + +input_shape = [1, 1, 175, 175, 175] + +BATCH_SIZE= 1 +device_ids = [1, 0, 2, 3] +GPUS = len(device_ids) +EXP_DIR = "./results/unigradicon/" + +class GradientICONSparse(network_wrappers.RegistrationModule): + def __init__(self, network, similarity, lmbda): + + super().__init__() + + self.regis_net = network + 
self.lmbda = lmbda + self.similarity = similarity + + def forward(self, image_A, image_B): + + assert self.identity_map.shape[2:] == image_A.shape[2:] + assert self.identity_map.shape[2:] == image_B.shape[2:] + + # Tag used elsewhere for optimization. + # Must be set at beginning of forward b/c not preserved by .cuda() etc + self.identity_map.isIdentity = True + + self.phi_AB = self.regis_net(image_A, image_B) + self.phi_BA = self.regis_net(image_B, image_A) + + self.phi_AB_vectorfield = self.phi_AB(self.identity_map) + self.phi_BA_vectorfield = self.phi_BA(self.identity_map) + + # tag images during warping so that the similarity measure + # can use information about whether a sample is interpolated + # or extrapolated + + if getattr(self.similarity, "isInterpolated", False): + # tag images during warping so that the similarity measure + # can use information about whether a sample is interpolated + # or extrapolated + inbounds_tag = torch.zeros([image_A.shape[0]] + [1] + list(image_A.shape[2:]), device=image_A.device) + if len(self.input_shape) - 2 == 3: + inbounds_tag[:, :, 1:-1, 1:-1, 1:-1] = 1.0 + elif len(self.input_shape) - 2 == 2: + inbounds_tag[:, :, 1:-1, 1:-1] = 1.0 + else: + inbounds_tag[:, :, 1:-1] = 1.0 + else: + inbounds_tag = None + + self.warped_image_A = compute_warped_image_multiNC( + torch.cat([image_A, inbounds_tag], axis=1) if inbounds_tag is not None else image_A, + self.phi_AB_vectorfield, + self.spacing, + 1, + zero_boundary=True + ) + self.warped_image_B = compute_warped_image_multiNC( + torch.cat([image_B, inbounds_tag], axis=1) if inbounds_tag is not None else image_B, + self.phi_BA_vectorfield, + self.spacing, + 1, + zero_boundary=True + ) + + similarity_loss = self.similarity( + self.warped_image_A, image_B + ) + self.similarity(self.warped_image_B, image_A) + + device = image_A.device + if len(self.input_shape) - 2 == 3: + Iepsilon = ( + self.identity_map + + 2 * torch.randn(*self.identity_map.shape).to(device) + / self.identity_map.shape[-1] + )[:, :, ::2, ::2, ::2] + elif len(self.input_shape) - 2 == 2: + Iepsilon = ( + self.identity_map + + 2 * torch.randn(*self.identity_map.shape).to(device) + / self.identity_map.shape[-1] + )[:, :, ::2, ::2] + + # compute squared Frobenius of Jacobian of icon error + + direction_losses = [] + + approximate_Iepsilon = self.phi_AB(self.phi_BA(Iepsilon)) + + inverse_consistency_error = Iepsilon - approximate_Iepsilon + + delta = 0.001 + + if len(self.identity_map.shape) == 4: + dx = torch.tensor([[[[delta]], [[0.0]]]]).to(device) + dy = torch.tensor([[[[0.0]], [[delta]]]]).to(device) + direction_vectors = (dx, dy) + + elif len(self.identity_map.shape) == 5: + dx = torch.tensor([[[[[delta]]], [[[0.0]]], [[[0.0]]]]]).to(device) + dy = torch.tensor([[[[[0.0]]], [[[delta]]], [[[0.0]]]]]).to(device) + dz = torch.tensor([[[[0.0]]], [[[0.0]]], [[[delta]]]]).to(device) + direction_vectors = (dx, dy, dz) + elif len(self.identity_map.shape) == 3: + dx = torch.tensor([[[delta]]]).to(device) + direction_vectors = (dx,) + + for d in direction_vectors: + approximate_Iepsilon_d = self.phi_AB(self.phi_BA(Iepsilon + d)) + inverse_consistency_error_d = Iepsilon + d - approximate_Iepsilon_d + grad_d_icon_error = ( + inverse_consistency_error - inverse_consistency_error_d + ) / delta + direction_losses.append(torch.mean(grad_d_icon_error**2)) + + inverse_consistency_loss = sum(direction_losses) + + all_loss = self.lmbda * inverse_consistency_loss + similarity_loss + + transform_magnitude = torch.mean( + (self.identity_map - 
self.phi_AB_vectorfield) ** 2 + ) + return icon.losses.ICONLoss( + all_loss, + inverse_consistency_loss, + similarity_loss, + transform_magnitude, + icon.losses.flips(self.phi_BA_vectorfield), + ) + + def clean(self): + del self.phi_AB, self.phi_BA, self.phi_AB_vectorfield, self.phi_BA_vectorfield, self.warped_image_A, self.warped_image_B + + +def get_dataset(): + data_num = 1000 + return ConcatDataset( + ( + COPDDataset(desired_shape=input_shape[2:], device=device_ids[0], ROI_only=True, data_num=data_num), + OAIDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=data_num), + HCPDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=data_num), + L2rAbdomenDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=data_num) + ) + ) + +def augment(image_A, image_B): + device = image_A.device + identity_list = [] + for i in range(image_A.shape[0]): + identity = torch.tensor([[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], device=device) + idxs = set((0, 1, 2)) + for j in range(3): + k = random.choice(list(idxs)) + idxs.remove(k) + identity[0, j, k] = 1 + identity = identity * (torch.randint_like(identity, 0, 2, device=device) * 2 - 1) + identity_list.append(identity) + + identity = torch.cat(identity_list) + + noise = torch.randn((image_A.shape[0], 3, 4), device=device) + + forward = identity + .05 * noise + + grid_shape = list(image_A.shape) + grid_shape[1] = 3 + forward_grid = F.affine_grid(forward, grid_shape) + + if image_A.shape[1] > 1: + # Then we have segmentations + warped_A = F.grid_sample(image_A[:, :1], forward_grid, padding_mode='border') + warped_A_seg = F.grid_sample(image_A[:, 1:], forward_grid, mode='nearest', padding_mode='border') + warped_A = torch.cat([warped_A, warped_A_seg], axis=1) + else: + warped_A = F.grid_sample(image_A, forward_grid, padding_mode='border') + + noise = torch.randn((image_A.shape[0], 3, 4), device=device) + forward = identity + .05 * noise + + grid_shape = list(image_A.shape) + grid_shape[1] = 3 + forward_grid = F.affine_grid(forward, grid_shape) + + if image_B.shape[1] > 1: + # Then we have segmentations + warped_B = F.grid_sample(image_B[:, :1], forward_grid, padding_mode='border') + warped_B_seg = F.grid_sample(image_B[:, 1:], forward_grid, mode='nearest', padding_mode='border') + warped_B = torch.cat([warped_B, warped_B_seg], axis=1) + else: + warped_B = F.grid_sample(image_B, forward_grid, padding_mode='border') + + return warped_A, warped_B + +def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.LNCC(sigma=5)): + dimension = len(input_shape) - 2 + inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=dimension)) + + for _ in range(2): + inner_net = icon.TwoStepRegistration( + icon.DownsampleRegistration(inner_net, dimension=dimension), + icon.FunctionFromVectorField(networks.tallUNet2(dimension=dimension)) + ) + if include_last_step: + inner_net = icon.TwoStepRegistration(inner_net, icon.FunctionFromVectorField(networks.tallUNet2(dimension=dimension))) + + net = GradientICONSparse(inner_net, loss_fn, lmbda=lmbda) + net.assign_identity_map(input_shape) + return net + +def train_kernel(optimizer, net, moving_image, fixed_image, writer, ite): + optimizer.zero_grad() + loss_object = net(moving_image, fixed_image) + loss = torch.mean(loss_object.all_loss) + loss.backward() + optimizer.step() + # print(to_floats(loss_object)) + write_stats(writer, loss_object, ite, prefix="train/") + +def train( + net, + optimizer, + data_loader, + val_data_loader, + epochs=200, + 
eval_period=-1, + save_period=-1, + step_callback=(lambda net: None), + unwrapped_net=None, + data_augmenter=None, +): + """A training function intended for long running experiments, with tensorboard logging + and model checkpoints. Use for medical registration training + """ + import footsteps + from torch.utils.tensorboard import SummaryWriter + + if unwrapped_net is None: + unwrapped_net = net + + loss_curve = [] + writer = SummaryWriter( + footsteps.output_dir + "/logs/" + datetime.now().strftime("%Y%m%d-%H%M%S"), + flush_secs=30, + ) + + iteration = 0 + for epoch in tqdm(range(epochs)): + for moving_image, fixed_image in data_loader: + moving_image, fixed_image = moving_image.cuda(), fixed_image.cuda() + if data_augmenter is not None: + with torch.no_grad(): + moving_image, fixed_image = data_augmenter(moving_image, fixed_image) + train_kernel(optimizer, net, moving_image, fixed_image, + writer, iteration) + iteration += 1 + + step_callback(unwrapped_net) + + + if epoch % save_period == 0: + torch.save( + optimizer.state_dict(), + footsteps.output_dir + "checkpoints/optimizer_weights_" + str(epoch), + ) + torch.save( + unwrapped_net.regis_net.state_dict(), + footsteps.output_dir + "checkpoints/network_weights_" + str(epoch), + ) + + if epoch % eval_period == 0: + visualization_moving, visualization_fixed = next(iter(val_data_loader)) + visualization_moving, visualization_fixed = visualization_moving[:, :1].cuda(), visualization_fixed[:, :1].cuda() + unwrapped_net.eval() + warped = [] + with torch.no_grad(): + eval_loss = unwrapped_net(visualization_moving, visualization_fixed) + write_stats(writer, eval_loss, epoch, prefix="val/") + warped = unwrapped_net.warped_image_A.cpu() + del eval_loss + unwrapped_net.clean() + unwrapped_net.train() + + def render(im): + if len(im.shape) == 5: + im = im[:, :, :, im.shape[3] // 2] + if torch.min(im) < 0: + im = im - torch.min(im) + if torch.max(im) > 1: + im = im / torch.max(im) + return im[:4, [0, 0, 0]].detach().cpu() + + writer.add_images( + "moving_image", render(visualization_moving[:4]), epoch, dataformats="NCHW" + ) + writer.add_images( + "fixed_image", render(visualization_fixed[:4]), epoch, dataformats="NCHW" + ) + writer.add_images( + "warped_moving_image", + render(warped), + epoch, + dataformats="NCHW", + ) + writer.add_images( + "difference", + render(torch.clip((warped[:4, :1] - visualization_fixed[:4, :1].cpu()) + 0.5, 0, 1)), + epoch, + dataformats="NCHW", + ) + + torch.save( + optimizer.state_dict(), + footsteps.output_dir + "checkpoints/optimizer_weights_" + str(epoch), + ) + torch.save( + unwrapped_net.regis_net.state_dict(), + footsteps.output_dir + "checkpoints/network_weights_" + str(epoch), + ) + +def train_two_stage(input_shape, data_loader, val_data_loader, GPUS, epochs, eval_period, save_period, resume_from): + + net = make_network(input_shape, include_last_step=False) + + torch.cuda.set_device(device_ids[0]) + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + # device = f"cuda:{device_ids[0]}" + + # Continue train + if resume_from != "": + print("Resume from: ", resume_from) + net.regis_net.load_state_dict(torch.load(resume_from, map_location="cpu")) + + if GPUS == 1: + net_par = net.cuda() + else: + net_par = torch.nn.DataParallel(net, device_ids=device_ids, output_device=device_ids[0]).cuda() + optimizer = torch.optim.Adam(net_par.parameters(), lr=0.00005) + + if resume_from != "": + optimizer.load_state_dict(torch.load(resume_from.replace("network_weights_", "optimizer_weights_"), 
map_location="cpu")) + + net_par.train() + + print("start train.") + train(net_par, optimizer, data_loader, val_data_loader, unwrapped_net=net, + epochs=epochs[0], eval_period=eval_period, save_period=save_period, data_augmenter=augment) + + torch.save( + net.regis_net.state_dict(), + footsteps.output_dir + "checkpoints/Step_1_final.trch", + ) + + net_2 = make_network(input_shape, include_last_step=True) + + net_2.regis_net.netPhi.load_state_dict(net.regis_net.state_dict()) + + # Continue train + # if resume_from != "": + # print("Resume from: ", resume_from) + # net_2.regis_net.load_state_dict(torch.load(resume_from, map_location="cpu")) + + del net + del net_par + del optimizer + + if GPUS == 1: + net_2_par = net_2.cuda() + else: + net_2_par = torch.nn.DataParallel(net_2, device_ids=device_ids, output_device=device_ids[0]).cuda() + optimizer = torch.optim.Adam(net_2_par.parameters(), lr=0.00005) + + # if resume_from != "": + # optimizer.load_state_dict(torch.load(resume_from.replace("network_weights_", "optimizer_weights_"), map_location="cpu")) + + net_2_par.train() + + # We're being weird by training two networks in one script. This hack keeps + # the second training from overwriting the outputs of the first. + footsteps.output_dir_impl = footsteps.output_dir + "2nd_step/" + os.makedirs(footsteps.output_dir + "checkpoints", exist_ok=True) + + train(net_2_par, optimizer, data_loader, val_data_loader, unwrapped_net=net_2, epochs=epochs[1], eval_period=eval_period, save_period=save_period, data_augmenter=augment) + + torch.save( + net_2.regis_net.state_dict(), + footsteps.output_dir + "checkpoints/Step_2_final.trch", + ) + +def finetune_execute(model, image_A, image_B, steps, lr=2e-5): + # state_dict = model.state_dict() + model.train() + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + for _ in range(steps): + optimizer.zero_grad() + loss_tuple = model(image_A, image_B) + print(loss_tuple) + loss_tuple[0].backward() + optimizer.step() + with torch.no_grad(): + loss = model(image_A, image_B) + # model.load_state_dict(state_dict) + model.eval() + return loss + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--resume_from", required=False, default="") + args = parser.parse_args() + resume_from = args.resume_from + + import footsteps + footsteps.initialize(output_root=EXP_DIR) + + dataset = get_dataset() + dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE*GPUS, + shuffle=True, + num_workers=4, + drop_last=True, + ) + val_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + num_workers=4, + drop_last=True, + ) + print("Finish data loading...") + + os.makedirs(footsteps.output_dir + "checkpoints", exist_ok=True) + + train_two_stage(input_shape, dataloader, val_dataloader, GPUS, [801,201], 20, 20, resume_from) \ No newline at end of file From 6a2ae8b28d4dc2cf95941ce47e9b222e2f6d3384 Mon Sep 17 00:00:00 2001 From: Lin Tian Date: Thu, 22 Aug 2024 09:44:21 -0400 Subject: [PATCH 2/4] Remove the duplicate definition of the model. 
--- training/train.py | 133 +--------------------------------------------- 1 file changed, 2 insertions(+), 131 deletions(-) diff --git a/training/train.py b/training/train.py index 9595ba8..f64fb07 100644 --- a/training/train.py +++ b/training/train.py @@ -3,19 +3,16 @@ from datetime import datetime from tqdm import tqdm -import numpy as np import torch import torch.nn.functional as F from dataset import COPDDataset, HCPDataset, OAIDataset, L2rAbdomenDataset from torch.utils.data import ConcatDataset, DataLoader import icon_registration as icon -import icon_registration.network_wrappers as network_wrappers import icon_registration.networks as networks -from icon_registration import config from icon_registration.losses import ICONLoss, to_floats -from icon_registration.mermaidlite import compute_warped_image_multiNC +from unigradicon import GradientICONSparse def write_stats(writer, stats: ICONLoss, ite, prefix=""): for k, v in to_floats(stats)._asdict().items(): @@ -23,137 +20,11 @@ def write_stats(writer, stats: ICONLoss, ite, prefix=""): input_shape = [1, 1, 175, 175, 175] -BATCH_SIZE= 1 +BATCH_SIZE= 4 device_ids = [1, 0, 2, 3] GPUS = len(device_ids) EXP_DIR = "./results/unigradicon/" -class GradientICONSparse(network_wrappers.RegistrationModule): - def __init__(self, network, similarity, lmbda): - - super().__init__() - - self.regis_net = network - self.lmbda = lmbda - self.similarity = similarity - - def forward(self, image_A, image_B): - - assert self.identity_map.shape[2:] == image_A.shape[2:] - assert self.identity_map.shape[2:] == image_B.shape[2:] - - # Tag used elsewhere for optimization. - # Must be set at beginning of forward b/c not preserved by .cuda() etc - self.identity_map.isIdentity = True - - self.phi_AB = self.regis_net(image_A, image_B) - self.phi_BA = self.regis_net(image_B, image_A) - - self.phi_AB_vectorfield = self.phi_AB(self.identity_map) - self.phi_BA_vectorfield = self.phi_BA(self.identity_map) - - # tag images during warping so that the similarity measure - # can use information about whether a sample is interpolated - # or extrapolated - - if getattr(self.similarity, "isInterpolated", False): - # tag images during warping so that the similarity measure - # can use information about whether a sample is interpolated - # or extrapolated - inbounds_tag = torch.zeros([image_A.shape[0]] + [1] + list(image_A.shape[2:]), device=image_A.device) - if len(self.input_shape) - 2 == 3: - inbounds_tag[:, :, 1:-1, 1:-1, 1:-1] = 1.0 - elif len(self.input_shape) - 2 == 2: - inbounds_tag[:, :, 1:-1, 1:-1] = 1.0 - else: - inbounds_tag[:, :, 1:-1] = 1.0 - else: - inbounds_tag = None - - self.warped_image_A = compute_warped_image_multiNC( - torch.cat([image_A, inbounds_tag], axis=1) if inbounds_tag is not None else image_A, - self.phi_AB_vectorfield, - self.spacing, - 1, - zero_boundary=True - ) - self.warped_image_B = compute_warped_image_multiNC( - torch.cat([image_B, inbounds_tag], axis=1) if inbounds_tag is not None else image_B, - self.phi_BA_vectorfield, - self.spacing, - 1, - zero_boundary=True - ) - - similarity_loss = self.similarity( - self.warped_image_A, image_B - ) + self.similarity(self.warped_image_B, image_A) - - device = image_A.device - if len(self.input_shape) - 2 == 3: - Iepsilon = ( - self.identity_map - + 2 * torch.randn(*self.identity_map.shape).to(device) - / self.identity_map.shape[-1] - )[:, :, ::2, ::2, ::2] - elif len(self.input_shape) - 2 == 2: - Iepsilon = ( - self.identity_map - + 2 * torch.randn(*self.identity_map.shape).to(device) - / 
self.identity_map.shape[-1] - )[:, :, ::2, ::2] - - # compute squared Frobenius of Jacobian of icon error - - direction_losses = [] - - approximate_Iepsilon = self.phi_AB(self.phi_BA(Iepsilon)) - - inverse_consistency_error = Iepsilon - approximate_Iepsilon - - delta = 0.001 - - if len(self.identity_map.shape) == 4: - dx = torch.tensor([[[[delta]], [[0.0]]]]).to(device) - dy = torch.tensor([[[[0.0]], [[delta]]]]).to(device) - direction_vectors = (dx, dy) - - elif len(self.identity_map.shape) == 5: - dx = torch.tensor([[[[[delta]]], [[[0.0]]], [[[0.0]]]]]).to(device) - dy = torch.tensor([[[[[0.0]]], [[[delta]]], [[[0.0]]]]]).to(device) - dz = torch.tensor([[[[0.0]]], [[[0.0]]], [[[delta]]]]).to(device) - direction_vectors = (dx, dy, dz) - elif len(self.identity_map.shape) == 3: - dx = torch.tensor([[[delta]]]).to(device) - direction_vectors = (dx,) - - for d in direction_vectors: - approximate_Iepsilon_d = self.phi_AB(self.phi_BA(Iepsilon + d)) - inverse_consistency_error_d = Iepsilon + d - approximate_Iepsilon_d - grad_d_icon_error = ( - inverse_consistency_error - inverse_consistency_error_d - ) / delta - direction_losses.append(torch.mean(grad_d_icon_error**2)) - - inverse_consistency_loss = sum(direction_losses) - - all_loss = self.lmbda * inverse_consistency_loss + similarity_loss - - transform_magnitude = torch.mean( - (self.identity_map - self.phi_AB_vectorfield) ** 2 - ) - return icon.losses.ICONLoss( - all_loss, - inverse_consistency_loss, - similarity_loss, - transform_magnitude, - icon.losses.flips(self.phi_BA_vectorfield), - ) - - def clean(self): - del self.phi_AB, self.phi_BA, self.phi_AB_vectorfield, self.phi_BA_vectorfield, self.warped_image_A, self.warped_image_B - - def get_dataset(): data_num = 1000 return ConcatDataset( From f78e27ff698cbeb7c2a30eb3842050b1422103dd Mon Sep 17 00:00:00 2001 From: Basar Demir Date: Mon, 28 Oct 2024 02:30:37 -0400 Subject: [PATCH 3/4] Add the multigradicon training scripts. --- src/unigradicon/__init__.py | 55 ++- training/dataset_multi.py | 813 ++++++++++++++++++++++++++++++++++++ training/train.py | 34 +- training/train_multi.py | 393 +++++++++++++++++ 4 files changed, 1244 insertions(+), 51 deletions(-) create mode 100644 training/dataset_multi.py create mode 100644 training/train_multi.py diff --git a/src/unigradicon/__init__.py b/src/unigradicon/__init__.py index d284ee5..3395c82 100644 --- a/src/unigradicon/__init__.py +++ b/src/unigradicon/__init__.py @@ -15,23 +15,27 @@ from icon_registration.mermaidlite import compute_warped_image_multiNC import icon_registration.itk_wrapper - - input_shape = [1, 1, 175, 175, 175] class GradientICONSparse(network_wrappers.RegistrationModule): - def __init__(self, network, similarity, lmbda): + def __init__(self, network, similarity, lmbda, use_label=False): super().__init__() self.regis_net = network self.lmbda = lmbda self.similarity = similarity + self.use_label = use_label - def forward(self, image_A, image_B): + def forward(self, image_A, image_B, label_A=None, label_B=None): assert self.identity_map.shape[2:] == image_A.shape[2:] assert self.identity_map.shape[2:] == image_B.shape[2:] + if self.use_label: + label_A = image_A if label_A is None else label_A + label_B = image_B if label_B is None else label_B + assert self.identity_map.shape[2:] == label_A.shape[2:] + assert self.identity_map.shape[2:] == label_B.shape[2:] # Tag used elsewhere for optimization. 
# Must be set at beginning of forward b/c not preserved by .cuda() etc @@ -75,10 +79,29 @@ def forward(self, image_A, image_B): 1, zero_boundary=True ) - - similarity_loss = self.similarity( - self.warped_image_A, image_B - ) + self.similarity(self.warped_image_B, image_A) + + if self.use_label: + self.warped_label_A = compute_warped_image_multiNC( + torch.cat([label_A, inbounds_tag], axis=1) if inbounds_tag is not None else label_A, + self.phi_AB_vectorfield, + self.spacing, + 1, + ) + + self.warped_label_B = compute_warped_image_multiNC( + torch.cat([label_B, inbounds_tag], axis=1) if inbounds_tag is not None else label_B, + self.phi_BA_vectorfield, + self.spacing, + 1, + ) + + similarity_loss = self.similarity( + self.warped_label_A, label_B + ) + self.similarity(self.warped_label_B, label_A) + else: + similarity_loss = self.similarity( + self.warped_image_A, image_B + ) + self.similarity(self.warped_image_B, image_A) if len(self.input_shape) - 2 == 3: Iepsilon = ( @@ -142,8 +165,10 @@ def forward(self, image_A, image_B): def clean(self): del self.phi_AB, self.phi_BA, self.phi_AB_vectorfield, self.phi_BA_vectorfield, self.warped_image_A, self.warped_image_B + if self.use_label: + del self.warped_label_A, self.warped_label_B -def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.LNCC(sigma=5)): +def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.LNCC(sigma=5), use_label=False): dimension = len(input_shape) - 2 inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=dimension)) @@ -155,7 +180,7 @@ def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.L if include_last_step: inner_net = icon.TwoStepRegistration(inner_net, icon.FunctionFromVectorField(networks.tallUNet2(dimension=dimension))) - net = GradientICONSparse(inner_net, loss_fn, lmbda=lmbda) + net = GradientICONSparse(inner_net, loss_fn, lmbda=lmbda, use_label=use_label) net.assign_identity_map(input_shape) return net @@ -237,7 +262,7 @@ def preprocess(image, modality="ct", segmentation=None): min_ = -1000 max_ = 1000 image = itk.CastImageFilter[type(image), itk.Image[itk.F, 3]].New()(image) - image = itk.clamp_image_filter(image, Bounds=(-1000, 1000)) + image = itk.clamp_image_filter(image, Bounds=(min_, max_)) elif modality == "mri": image = itk.CastImageFilter[type(image), itk.Image[itk.F, 3]].New()(image) min_, _ = itk.image_intensity_min_max(image) @@ -358,10 +383,4 @@ def warp_command(): use_reference_image=True, reference_image=fixed ) - itk.imwrite(warped_moving_image, args.warped_moving_out) - - - - - - + itk.imwrite(warped_moving_image, args.warped_moving_out) \ No newline at end of file diff --git a/training/dataset_multi.py b/training/dataset_multi.py new file mode 100644 index 0000000..08cf330 --- /dev/null +++ b/training/dataset_multi.py @@ -0,0 +1,813 @@ +import torch +import os +import torch.nn.functional as F +import random +import numpy as np +import itk +import glob +import SimpleITK as sitk +from tqdm import tqdm +import blosc +import json + +blosc.set_nthreads(1) +DATASET_DIR = "./data/multiGradICON/" + +class COPDDataset(torch.utils.data.Dataset): + def __init__( + self, + scale="2xdown", + data_path=f"{DATASET_DIR}/half_res_preprocessed_transposed_SI", + ROI_only=False, + data_num=-1, + desired_shape=None, + device="cpu", + return_labels=False + ): + self.imgs = torch.load(f"{data_path}/lungs_train_{scale}_scaled", map_location="cpu") + self.data_num = data_num + self.desired_shape = desired_shape + self.device = 
device + self.return_labels = return_labels + self.modalities = ['ct'] + self.anatomies = ['lung'] + self.region_num = 1 + + print("Processing COPD data.") + if ROI_only: + segs = torch.load( + f"{data_path}/lungs_seg_train_{scale}_scaled", map_location="cpu" + )[: self.data_num] + self.imgs = list(map(lambda x: (self.pack_and_process_image(x[0][0], x[1][0]), self.pack_and_process_image(x[0][1], x[1][1])), tqdm(zip(self.imgs, segs)))) + else: + self.imgs = list(map(lambda x: (self.pack_and_process_image(x[0]), self.pack_and_process_image(x[1])), tqdm(self.imgs))) + + def pack_and_process_image(self, img, seg=None): + processed_image = self.process(img, self.desired_shape, self.device, seg)[0] + array_image = processed_image.numpy() + return blosc.pack_array(array_image) + + def process(self, img, desired_shape=None, device="cpu", seg=None): + img = img.to(device) + im_min, im_max = torch.min(img), torch.max(img) + img = (img-im_min) / (im_max-im_min) + if seg is not None: + seg = seg.to(device) + img = (img * seg).float() + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + img_a, img_b = random.choice(self.imgs) + + if self.return_labels: + return blosc.unpack_array(img_a), blosc.unpack_array(img_b), blosc.unpack_array(img_a), blosc.unpack_array(img_b) + else: + return blosc.unpack_array(img_a), blosc.unpack_array(img_b) + +class BratsRegDataset(torch.utils.data.Dataset): + def __init__( + self, + data_path=f"{DATASET_DIR}/BraTS-Reg/BraTSReg_Training_Data_v3/", + data_num=1000, + desired_shape=None, + device="cpu", + return_labels=False, + randomization = 'random' + ): + + folders = sorted(glob.glob(data_path + '*/')) + self.pre_images, self.post_images = {'t1': [], 't1ce': [], 't2': [], 'flair': []}, {'t1': [], 't1ce': [], 't2': [], 'flair': []} + self.desired_shape = desired_shape + self.device = device + self.return_labels = return_labels + self.randomization = randomization + assert self.randomization in ['random', 'fixed'] + self.anatomies = ['brain'] + self.region_num = 1 + + for folder in tqdm(folders): + t1 = itk.imread(glob.glob(folder + 'BraTSReg_*_00_*_t1.nii.gz')[0]) + t1ce = itk.imread(glob.glob(folder + 'BraTSReg_*_00_*_t1ce.nii.gz')[0]) + t2 = itk.imread(glob.glob(folder + 'BraTSReg_*_00_*_t2.nii.gz')[0]) + flair = itk.imread(glob.glob(folder + 'BraTSReg_*_00_*_flair.nii.gz')[0]) + + self.pre_images['t1'].append(self.pack_and_process_image(t1)) + self.pre_images['t1ce'].append(self.pack_and_process_image(t1ce)) + self.pre_images['t2'].append(self.pack_and_process_image(t2)) + self.pre_images['flair'].append(self.pack_and_process_image(flair)) + + t1 = itk.imread(glob.glob(folder + 'BraTSReg_*_01_*_t1.nii.gz')[0]) + t1ce = itk.imread(glob.glob(folder + 'BraTSReg_*_01_*_t1ce.nii.gz')[0]) + t2 = itk.imread(glob.glob(folder + 'BraTSReg_*_01_*_t2.nii.gz')[0]) + flair = itk.imread(glob.glob(folder + 'BraTSReg_*_01_*_flair.nii.gz')[0]) + + self.post_images['t1'].append(self.pack_and_process_image(t1)) + self.post_images['t1ce'].append(self.pack_and_process_image(t1ce)) + self.post_images['t2'].append(self.pack_and_process_image(t2)) + self.post_images['flair'].append(self.pack_and_process_image(flair)) + + self.data_num = data_num + self.image_num = len(self.pre_images['t1']) + self.modalities = list(self.pre_images.keys()) + + def pack_and_process_image(self, image): + processed_image = self.process(torch.tensor(np.asarray(image))[None, None], 
self.desired_shape, self.device)[0] + array_image = processed_image.numpy() + return blosc.pack_array(array_image) + + def process(self, img, desired_shape=None, device="cpu"): + img = img.to(device).float() + im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99) + img = torch.clip(img, im_min, im_max) + img = (img-im_min) / (im_max-im_min) + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu().float() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + idx1 = random.randint(0, self.image_num-1) + + img_a = self.pre_images[random.choice(self.modalities)][idx1] + img_b = self.post_images[random.choice(self.modalities)][idx1] + + if not self.return_labels: + return blosc.unpack_array(img_a), blosc.unpack_array(img_b) + + if self.randomization == 'random': + label_a = self.pre_images[random.choice(self.modalities)][idx1] + label_b = self.post_images[random.choice(self.modalities)][idx1] + else: + modality = random.choice(self.modalities) + label_a = self.pre_images[modality][idx1] + label_b = self.post_images[modality][idx1] + + return blosc.unpack_array(img_a), blosc.unpack_array(img_b), blosc.unpack_array(label_a), blosc.unpack_array(label_b) + +class L2rAbdomenDataset(torch.utils.data.Dataset): + def __init__( + self, + data_path=f"{DATASET_DIR}/AbdomenCTCT", + data_num=1000, + desired_shape=None, + device="cpu", + return_labels=False, + randomization = 'random', + augmentation = True + ): + cases = list(map(lambda x: os.path.join(f"{data_path}/imagesTr", x), os.listdir(f"{data_path}/imagesTr/"))) + self.desired_shape = desired_shape + self.device = device + self.return_labels = return_labels + self.randomization = randomization + assert self.randomization in ['random', 'fixed'] + self.anatomies = ['abdomen'] + self.region_num = 1 + + self.imgs = {'ct' : []} + if augmentation: + self.imgs['1-ct'] = [] + + print("Processing L2R Abdomen data.") + for i in tqdm(range(len(cases))): + case_path = cases[i] + self.imgs['ct'].append(self.pack_and_process_image(case_path)) + if augmentation: + self.imgs['1-ct'].append(self.pack_and_process_image(case_path, invert=True)) + + self.img_num = len(self.imgs) + self.data_num = data_num + self.modalities = list(self.imgs.keys()) + + def pack_and_process_image(self, case_path, invert=False): + processed_image = self.process(torch.tensor(np.asarray(itk.imread(case_path)))[None, None], self.desired_shape, self.device)[0] + array_image = processed_image.numpy() + if invert: + array_image = 1 - array_image + return blosc.pack_array(array_image) + + def process(self, img, desired_shape=None, device="cpu"): + img = img.to(device) + img = (torch.clip(img.float(), -1000, 1000)+1000)/2000 + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + idx_a = random.randint(0, self.img_num - 1) + idx_b = random.randint(0, self.img_num - 1) + img_a = self.imgs[random.choice(self.modalities)][idx_a] + img_b = self.imgs[random.choice(self.modalities)][idx_b] + + if not self.return_labels: + return blosc.unpack_array(img_a), blosc.unpack_array(img_b) + + if self.randomization == 'random': + label_a = self.imgs[random.choice(self.modalities)][idx_a] + label_b = self.imgs[random.choice(self.modalities)][idx_b] + else: + modality = random.choice(self.modalities) + label_a = self.imgs[modality][idx_a] + label_b = self.imgs[modality][idx_b] + + return 
blosc.unpack_array(img_a), blosc.unpack_array(img_b), blosc.unpack_array(label_a), blosc.unpack_array(label_b) + +class HCPDataset(torch.utils.data.Dataset): + def __init__( + self, + scale="2xdown", + data_path=f"{DATASET_DIR}/ICON_brain_preprocessed_data", + data_num=1000, + desired_shape=None, + device="cpu", + return_labels=False, + randomization = 'random' + ): + self.desired_shape = desired_shape + self.device = device + self.return_labels = return_labels + self.randomization = randomization + assert self.randomization in ['random', 'fixed'] + self.anatomies = ['brain'] + self.region_num = 1 + + imgsT1 = torch.load( + f"{data_path}/brain_train_{scale}_scaled_T1", map_location="cpu" + ) + imgsT2 = torch.load( + f"{data_path}/brain_train_{scale}_scaled_T2", map_location="cpu" + ) + imgsT1 = list(map(lambda x: self.pack_and_process_image(x), tqdm(imgsT1))) + imgsT2 = list(map(lambda x: self.pack_and_process_image(x), tqdm(imgsT2))) + + self.imgs = {'T1': imgsT1, 'T2': imgsT2} + + self.img_num = len(imgsT1) + self.data_num = data_num + self.modalities = list(self.imgs.keys()) + + def pack_and_process_image(self, image): + processed_image = self.process(image, self.desired_shape, self.device)[0] + array_image = processed_image.numpy() + return blosc.pack_array(array_image) + + def process(self, img, desired_shape=None, device="cpu"): + img = img.to(device) + im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99) + img = torch.clip(img, im_min, im_max) + img = (img-im_min) / (im_max-im_min) + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + idx_a = random.randint(0, self.img_num - 1) + idx_b = random.randint(0, self.img_num - 1) + + img_a = self.imgs[random.choice(self.modalities)][idx_a] + img_b = self.imgs[random.choice(self.modalities)][idx_b] + + if not self.return_labels: + return blosc.unpack_array(img_a), blosc.unpack_array(img_b) + + if self.randomization == 'random': + label_a = self.imgs[random.choice(self.modalities)][idx_a] + label_b = self.imgs[random.choice(self.modalities)][idx_b] + else: + modality = random.choice(self.modalities) + label_a = self.imgs[modality][idx_a] + label_b = self.imgs[modality][idx_b] + + return blosc.unpack_array(img_a), blosc.unpack_array(img_b), blosc.unpack_array(label_a), blosc.unpack_array(label_b) + +class ABCDFAMDDataset(torch.utils.data.Dataset): + def __init__( + self, + phase="train", + data_path=f"{DATASET_DIR}/dti_scalars", + data_num=1000, + desired_shape=None, + device="cpu", + return_labels=False, + randomization = 'random' + ): + self.desired_shape = desired_shape + self.device = device + self.return_labels = return_labels + self.randomization = randomization + assert self.randomization in ['random', 'fixed'] + self.images = {} + self.anatomies = ['brain'] + self.region_num = 1 + + md_files = sorted(glob.glob(f'{data_path}/md/' + '*.nii.gz')) + fa_files = sorted(glob.glob(f'{data_path}/fa/' + '*.nii.gz')) + + if phase == "train": + fa_files = fa_files[10:] + md_files = md_files[10:] + elif phase == "val": + fa_files = fa_files[:10] + md_files = md_files[:10] + + fa_ids = {x.split('/')[-1].split('.')[0].split('_')[0].split('-')[1] : x for x in fa_files} + md_ids = {x.split('/')[-1].split('.')[0].split('_')[0].split('-')[1] : x for x in md_files} + + for fa_id in tqdm(fa_ids): + if fa_id not in md_ids: + continue + + fa = itk.imread(fa_ids[fa_id]) + md = itk.imread(md_ids[fa_id]) + 
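+            # Key by subject id so __getitem__ can sample either scalar map
+            # (FA or MD) of the same subject when forming pairs and labels.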
self.images[fa_id] = {'FA': self.pack_and_process_image(fa), 'MD': self.pack_and_process_image(md)} + + self.data_num = data_num + self.image_ids = list(self.images.keys()) + self.modalities = ['FA', 'MD'] + + def pack_and_process_image(self, image): + processed_image = self.process(self.process(torch.tensor(np.asarray(image))[None, None], self.desired_shape, self.device)[0]) + array_image = processed_image.numpy() + return blosc.pack_array(array_image) + + def process(self, img, desired_shape=None, device="cpu"): + img = img.to(device).float() + im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99) + img = torch.clip(img, im_min, im_max) + img = (img-im_min) / (im_max-im_min) + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu().float() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + index1 = random.choice(self.image_ids) + image1 = self.images[index1][random.choice(self.modalities)] + + index2 = random.choice(self.image_ids) + image2 = self.images[index2][random.choice(self.modalities)] + + if not self.return_labels: + return blosc.unpack_array(image1), blosc.unpack_array(image2) + + if self.randomization == 'random': + label1 = self.images[index1][random.choice(self.modalities)] + label2 = self.images[index2][random.choice(self.modalities)] + else: + modality = random.choice(self.modalities) + label1 = self.images[index1][modality] + label2 = self.images[index2][modality] + + return blosc.unpack_array(image1), blosc.unpack_array(image2), blosc.unpack_array(label1), blosc.unpack_array(label2) + +class ABCDDataset(torch.utils.data.Dataset): + def __init__( + self, + phase="train", + data_path=f"{DATASET_DIR}", + data_num=1000, + desired_shape=None, + device="cpu", + return_labels=False, + ): + + md_path = f'{data_path}/dti_scalars/md/' + md_files = sorted(glob.glob(md_path + '*.nii.gz')) + + fa_path = f'{data_path}/dti_scalars/fa/' + fa_files = sorted(glob.glob(fa_path + '*.nii.gz')) + + mri_path = f'{data_path}/structural_mri/' + mri_files = sorted(glob.glob(mri_path + '*_oriented_stripped.nii.gz')) + + if phase == "train": + fa_files = fa_files[10:] + md_files = md_files[10:] + mri_files = mri_files[10:] + elif phase == "val": + fa_files = fa_files[:10] + md_files = md_files[:10] + mri_files = mri_files[:10] + + fa_ids = {x.split('/')[-1].split('.')[0].split('_')[0].split('-')[1] : x for x in fa_files} + md_ids = {x.split('/')[-1].split('.')[0].split('_')[0].split('-')[1] : x for x in md_files} + mri_ids = {x.split('/')[-1].split('.')[0].split('_')[0].split('-')[1] : x for x in mri_files} + + self.desired_shape = desired_shape + self.device = device + self.return_labels = return_labels + self.images = {'FA': [], 'MD': [], 'T1': [], 'T2': []} + self.anatomies = ['brain'] + self.region_num = 1 + + for id in tqdm(fa_ids): + fa = itk.imread(fa_ids[id]) + self.images['FA'].append(self.pack_and_process_image(fa)) + md = itk.imread(md_ids[id]) + self.images['MD'].append(self.pack_and_process_image(md)) + + for mri_id in tqdm(mri_ids): + mri = mri_ids[mri_id].replace('T1w', 'modality').replace('T2w', 'modality') + if not os.path.exists(mri.replace('modality', 'T1w')) or not os.path.exists(mri.replace('modality', 'T2w')): + continue + + t1_mri = itk.imread(mri.replace('modality', 'T1w')) + self.images['T1'].append(self.pack_and_process_image(t1_mri)) + + t2_mri = itk.imread(mri.replace('modality', 'T2w')) + self.images['T2'].append(self.pack_and_process_image(t2_mri)) + + self.data_num = 
data_num + self.modalities = list(self.images.keys()) + + def pack_and_process_image(self, image): + processed_image = self.process(self.process(torch.tensor(np.asarray(image))[None, None], self.desired_shape, self.device)[0]) + array_image = processed_image.numpy() + return blosc.pack_array(array_image) + + def process(self, img, desired_shape=None, device="cpu"): + img = img.to(device).float() + im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99) + img = torch.clip(img, im_min, im_max) + img = (img-im_min) / (im_max-im_min) + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu().float() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + image1 = np.random.choice(self.images[np.random.choice(self.modalities)]) + image2 = np.random.choice(self.images[np.random.choice(self.modalities)]) + + if not self.return_labels: + return blosc.unpack_array(image1), blosc.unpack_array(image2) + else: + return blosc.unpack_array(image1), blosc.unpack_array(image2), blosc.unpack_array(image1), blosc.unpack_array(image2) + + +class OAIMMDataset(torch.utils.data.Dataset): + def __init__( + self, + data_path=f"{DATASET_DIR}/oai", + data_num=1000, + desired_shape=None, + device="cpu", + return_labels=False, + ): + self.desired_shape = desired_shape + self.device = device + self.return_labels = return_labels + self.anatomies = ['knee'] + self.region_num = 1 + + dataset_dess = torch.load(f"{data_path}/dess_images.pt", map_location="cpu") + dataset_T2 = torch.load(f"{data_path}t2_images.pt", map_location="cpu") + + self.images = {'DESS': [], 'T2': []} + + self.images['DESS'] = list(map(lambda x: self.pack_and_process_image(x), tqdm(dataset_dess))) + self.images['T2'] = list(map(lambda x: self.pack_and_process_image(x), tqdm(dataset_T2))) + + self.data_num = data_num + self.modalities = list(self.images.keys()) + + def pack_and_process_image(self, image): + processed_image = self.process(image, self.desired_shape, self.device)[0] + array_image = processed_image.numpy() + return blosc.pack_array(array_image) + + def process(self, img, desired_shape=None, device="cpu"): + img = img.to(device).float() + im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99) + img = torch.clip(img, im_min, im_max) + img = (img-im_min) / (im_max-im_min) + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + modality_1, modality_2 = random.choices(self.modalities, k=2) + img_a = random.choice(self.images[modality_1]) + img_b = random.choice(self.images[modality_2]) + + if self.return_labels: + return blosc.unpack_array(img_a), blosc.unpack_array(img_b), blosc.unpack_array(img_a), blosc.unpack_array(img_b) + else: + return blosc.unpack_array(img_a), blosc.unpack_array(img_b) + +class L2rMRCTDataset(torch.utils.data.Dataset): + def __init__( + self, + data_path=f"{DATASET_DIR}/AbdomenMRCT/", + data_num=1000, + desired_shape=None, + device="cpu", + phase = "train", + augmentation = True, + return_labels=False + ): + #inter-patient + self.phase = phase + self.device = device + self.return_labels = return_labels + self.anatomies = ['abdomen'] + self.region_num = 1 + + with open(f"{data_path}/AbdomenMRCT_dataset.json", 'r') as data_info: + data_info = json.loads(data_info.read()) + + if self.phase == "train": + mr_samples = [c["image"] for c in data_info["training"]["0"]] #mr + ct_samples = 
[c["image"] for c in data_info["training"]["1"]] #ct + else: + mr_samples = [c["fixed"] for c in data_info["registration_test"]] + ct_samples = [c["moving"] for c in data_info["registration_test"]] + + self.images = {'mr' : [], 'ct' : []} + + if augmentation: + self.images['1-ct'] = [] + + for path in tqdm(mr_samples): + image = np.asarray(itk.imread(os.path.join(data_path, path))) + image = torch.Tensor(np.array(image)).unsqueeze(0).unsqueeze(0) + image = self.process_mr(image, desired_shape, device)[0] + self.images['mr'].append(self.pack(image)) + + for path in tqdm(ct_samples): + image = np.asarray(itk.imread(os.path.join(data_path, path))) + image = torch.Tensor(np.array(image)).unsqueeze(0).unsqueeze(0) + image = self.process_ct(image, desired_shape, device)[0] + self.images['ct'].append(self.pack(image)) + if augmentation: + self.images['1-ct'].append(self.pack(1-image)) + + self.data_num = data_num + self.modalities = list(self.images.keys()) + + def pack(self, image): + return blosc.pack_array(image.numpy()) + + def process_label(self, label, desired_shape=None, device="cpu"): + label = label.to(device) + if desired_shape is not None: + label = F.interpolate(label, desired_shape, mode="nearest") + return label.cpu() + + def process_ct(self, img, desired_shape=None, device="cpu"): + img = img.to(device) + img = (torch.clip(img.float(), -1000, 1000)+1000)/2000 + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def process_mr(self, img, desired_shape=None, device="cpu"): + img = img.to(device) + im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99) + img = torch.clip(img, im_min, im_max) + img = (img-im_min) / (im_max-im_min) + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + modality1 = random.choice(self.modalities) + modality2 = random.choice(self.modalities) + + idx1 = random.randint(0, len(self.images[modality1])-1) + idx2 = random.randint(0, len(self.images[modality2])-1) + + img_a = self.images[modality1][idx1] + img_b = self.images[modality2][idx2] + + if self.return_labels: + return blosc.unpack_array(img_a), blosc.unpack_array(img_b), blosc.unpack_array(img_a), blosc.unpack_array(img_b) + else: + return blosc.unpack_array(img_a), blosc.unpack_array(img_b) + + +class UKBiobankDataset(torch.utils.data.Dataset): + def __init__( + self, + data_path=f"{DATASET_DIR}/uk-biobank/", + data_num=1000, + desired_shape=None, + device="cpu", + phase = "train", + return_labels=False, + randomization = 'random' + ): + #contains 6 regions of the body - each region contains a list of images + fat_weighted_regions = torch.load(data_path + 'regions_fat.pt', map_location='cpu') + water_weighted_regions = torch.load(data_path + 'regions_water.pt', map_location='cpu') + + self.desired_shape = desired_shape + self.device = device + self.return_labels = return_labels + self.randomization = randomization + assert self.randomization in ['random', 'fixed'] + self.anatomies = ['abdomen', 'lung', 'knee'] + + self.images = {'fat': fat_weighted_regions, 'water': water_weighted_regions} + + for region in tqdm(range(6)): + if phase == "train": + self.images['fat'][region] = list(map(lambda x: self.pack_and_process_image(x), self.images['fat'][region][10:])) + self.images['water'][region] = list(map(lambda x: self.pack_and_process_image(x), self.images['water'][region][10:])) + else: + 
self.images['fat'][region] = list(map(lambda x: self.pack_and_process_image(x), self.images['fat'][region][:10])) + self.images['water'][region] = list(map(lambda x: self.pack_and_process_image(x), self.images['water'][region][:10])) + + self.data_num = data_num + self.modalities = ['fat', 'water'] + self.region_num = 6 + + def pack_and_process_image(self, image): + processed_image = self.process(image, self.desired_shape, self.device)[0] + array_image = processed_image.numpy() + return blosc.pack_array(array_image) + + def process(self, img, desired_shape=None, device="cpu"): + img = torch.tensor(sitk.GetArrayFromImage(img)[None, None].astype(np.float32)) + img = img.to(device) + im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99) + img = torch.clip(img, im_min, im_max) + img = (img-im_min) / (im_max-im_min) + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.cpu() + + def __len__(self): + return self.data_num + + def __getitem__(self, idx): + region = random.randint(0, self.region_num-1) + modality1 = random.choice(self.modalities) + modality2 = random.choice(self.modalities) + + index1 = random.randint(0, len(self.images[modality1][region])-1) + index2 = random.randint(0, len(self.images[modality2][region])-1) + + img_a = self.images[modality1][region][index1] + img_b = self.images[modality2][region][index2] + + if not self.return_labels: + return blosc.unpack_array(img_a), blosc.unpack_array(img_b) + + if self.randomization == 'random': + label1 = self.images[random.choice(self.modalities)][region][index1] + label2 = self.images[random.choice(self.modalities)][region][index2] + else: + modality = random.choice(self.modalities) + label1 = self.images[modality][region][index1] + label2 = self.images[modality][region][index2] + + return blosc.unpack_array(img_a), blosc.unpack_array(img_b), blosc.unpack_array(label1), blosc.unpack_array(label2) + + +class PancreasDataset(torch.utils.data.Dataset): + def __init__(self, + phase="train", + data_path=f"{DATASET_DIR}/pancreas/", + data_num=1000, + desired_shape=(175, 175, 175), + device="cpu", + return_labels=False, + ): + + with open(f"{data_path}/{phase}_set.txt", 'r') as f: + self.img_list = [line.strip() for line in f] + + self.desired_shape = desired_shape + self.data_num = data_num + self.img_dict = {} + self.return_labels = return_labels + self.modalities = ['ct', 'cbct'] + self.region_num = 1 + self.anatomies = ['abdomen'] + + for idx in range(len(self.img_list)): + ith_info = self.img_list[idx].split(" ") + ct_img_name = ith_info[0] + cb_img_name = ith_info[1] + + ct_img_itk = sitk.ReadImage(ct_img_name) + cb_img_itk = sitk.ReadImage(cb_img_name) + + ct_img_arr = sitk.GetArrayFromImage(ct_img_itk) + cb_img_arr = sitk.GetArrayFromImage(cb_img_itk) + + ct_img_arr, cb_img_arr = self.process_training_data(ct_img_arr, cb_img_arr) + + self.img_dict[ct_img_name] = blosc.pack_array(ct_img_arr) + self.img_dict[cb_img_name] = blosc.pack_array(cb_img_arr) + + def __len__(self): + return self.data_num + + def process(self, img, desired_shape=None, device="cpu"): + img = torch.tensor(img).unsqueeze(0).unsqueeze(0).cpu() + img = img.to(device).float() + img = (torch.clip(img.float(), -1000, 1000)+1000)/2000 + if desired_shape is not None: + img = F.interpolate(img, desired_shape, mode="trilinear") + return img.squeeze().numpy().cpu().float() + + def process_training_data(self, ct_img_arr, cb_img_arr): + ct_img_arr = self.process(ct_img_arr, self.desired_shape) + cb_img_arr = 
+
+class PancreasDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        phase="train",
+        data_path=f"{DATASET_DIR}/pancreas/",
+        data_num=1000,
+        desired_shape=(175, 175, 175),
+        device="cpu",
+        return_labels=False,
+    ):
+        with open(f"{data_path}/{phase}_set.txt", 'r') as f:
+            self.img_list = [line.strip() for line in f]
+
+        self.desired_shape = desired_shape
+        self.data_num = data_num
+        self.img_dict = {}
+        self.return_labels = return_labels
+        self.modalities = ['ct', 'cbct']
+        self.region_num = 1
+        self.anatomies = ['abdomen']
+
+        for idx in range(len(self.img_list)):
+            ith_info = self.img_list[idx].split(" ")
+            ct_img_name = ith_info[0]
+            cb_img_name = ith_info[1]
+
+            ct_img_itk = sitk.ReadImage(ct_img_name)
+            cb_img_itk = sitk.ReadImage(cb_img_name)
+
+            ct_img_arr = sitk.GetArrayFromImage(ct_img_itk)
+            cb_img_arr = sitk.GetArrayFromImage(cb_img_itk)
+
+            ct_img_arr, cb_img_arr = self.process_training_data(ct_img_arr, cb_img_arr)
+
+            self.img_dict[ct_img_name] = blosc.pack_array(ct_img_arr)
+            self.img_dict[cb_img_name] = blosc.pack_array(cb_img_arr)
+
+    def __len__(self):
+        return self.data_num
+
+    def process(self, img, desired_shape=None, device="cpu"):
+        img = torch.tensor(img).unsqueeze(0).unsqueeze(0)
+        img = img.to(device).float()
+        img = (torch.clip(img, -1000, 1000) + 1000) / 2000
+        if desired_shape is not None:
+            img = F.interpolate(img, desired_shape, mode="trilinear")
+        # Move back to CPU before converting to numpy for blosc packing.
+        return img.squeeze().cpu().numpy()
+
+    def process_training_data(self, ct_img_arr, cb_img_arr):
+        ct_img_arr = self.process(ct_img_arr, self.desired_shape)
+        cb_img_arr = self.process(cb_img_arr, self.desired_shape)
+
+        return ct_img_arr, cb_img_arr
+
+    def __getitem__(self, idx):
+        idx = np.random.randint(0, len(self.img_list))
+        ith_info = self.img_list[idx].split(" ")
+        ct_img_name = ith_info[0]
+        cb_img_name = ith_info[1]
+
+        ct_img_arr = blosc.unpack_array(self.img_dict[ct_img_name])
+        cb_img_arr = blosc.unpack_array(self.img_dict[cb_img_name])
+
+        if not self.return_labels:
+            return ct_img_arr, cb_img_arr
+        else:
+            return ct_img_arr, cb_img_arr, ct_img_arr, cb_img_arr
+
+class L2rThoraxCBCTDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        data_path=f"{DATASET_DIR}/ThoraxCBCT",
+        data_num=1000,
+        desired_shape=None,
+        device="cpu",
+        return_labels=False,
+    ):
+        import json
+        with open(f"{data_path}/ThoraxCBCT_dataset.json", 'r') as data_info:
+            data_info = json.loads(data_info.read())
+        cases = [f"{data_path}/{c['image']}" for c in data_info["training"]]
+
+        self.modalities = {'0000': [], '0001': [], '0002': []}
+        self.desired_shape = desired_shape
+        self.device = device
+        self.return_labels = return_labels
+        self.anatomies = ['lung']
+        self.region_num = 1
+
+        for case in cases:
+            # The trailing _0000/_0001/_0002 filename suffix encodes the modality.
+            modality = case.split('/')[-1].split('_')[-1].split('.')[0]
+            self.modalities[modality].append(self.pack_and_process_image(itk.imread(case)))
+
+        self.data_num = data_num
+
+    def pack_and_process_image(self, image):
+        processed_image = self.process(torch.tensor(np.asarray(image))[None, None], self.desired_shape, self.device)[0]
+        array_image = processed_image.numpy()
+        return blosc.pack_array(array_image)
+
+    def process(self, img, desired_shape=None, device="cpu"):
+        img = img.to(device).float()
+        img = (torch.clip(img, -1000, 1000) + 1000) / 2000
+        if desired_shape is not None:
+            img = F.interpolate(img, desired_shape, mode="trilinear")
+        return img.cpu().float()
+
+    def __len__(self):
+        return self.data_num
+
+    def __getitem__(self, idx):
+        modality1 = random.choice(list(self.modalities.keys()))
+        modality2 = random.choice(list(self.modalities.keys()))
+
+        # Every patient has all three modalities stored in the same order, so
+        # a single patient index is valid across both modality lists.
+        patient_id = random.randint(0, len(self.modalities[modality1]) - 1)
+
+        image1 = self.modalities[modality1][patient_id]
+        image2 = self.modalities[modality2][patient_id]
+
+        if not self.return_labels:
+            return blosc.unpack_array(image1), blosc.unpack_array(image2)
+        else:
+            return blosc.unpack_array(image1), blosc.unpack_array(image2), blosc.unpack_array(image1), blosc.unpack_array(image2)
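+
+# Note that most __getitem__ implementations above ignore `idx` and sample a
+# pair at random, so `data_num` only sets the nominal epoch length. A hedged
+# usage sketch (illustrative only; the shape is an assumption):
+#
+#   ds = L2rThoraxCBCTDataset(desired_shape=(175, 175, 175))
+#   img_a, img_b = ds[0]  # any index yields a freshly sampled pair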
\ No newline at end of file
diff --git a/training/train.py b/training/train.py
index f64fb07..80ec6b2 100644
--- a/training/train.py
+++ b/training/train.py
@@ -12,7 +12,7 @@
 import icon_registration.networks as networks
 from icon_registration.losses import ICONLoss, to_floats
 
-from unigradicon import GradientICONSparse
+from unigradicon import make_network
 
 def write_stats(writer, stats: ICONLoss, ite, prefix=""):
     for k, v in to_floats(stats)._asdict().items():
@@ -84,22 +84,6 @@ def augment(image_A, image_B):
 
     return warped_A, warped_B
 
-def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.LNCC(sigma=5)):
-    dimension = len(input_shape) - 2
-    inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=dimension))
-
-    for _ in range(2):
-        inner_net = icon.TwoStepRegistration(
-            icon.DownsampleRegistration(inner_net, dimension=dimension),
-            icon.FunctionFromVectorField(networks.tallUNet2(dimension=dimension))
-        )
-    if include_last_step:
-        inner_net = icon.TwoStepRegistration(inner_net, icon.FunctionFromVectorField(networks.tallUNet2(dimension=dimension)))
-
-    net = GradientICONSparse(inner_net, loss_fn, lmbda=lmbda)
-    net.assign_identity_map(input_shape)
-    return net
-
 def train_kernel(optimizer, net, moving_image, fixed_image, writer, ite):
     optimizer.zero_grad()
     loss_object = net(moving_image, fixed_image)
@@ -280,22 +264,6 @@
         footsteps.output_dir + "checkpoints/Step_2_final.trch",
     )
 
-def finetune_execute(model, image_A, image_B, steps, lr=2e-5):
-    # state_dict = model.state_dict()
-    model.train()
-    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
-    for _ in range(steps):
-        optimizer.zero_grad()
-        loss_tuple = model(image_A, image_B)
-        print(loss_tuple)
-        loss_tuple[0].backward()
-        optimizer.step()
-    with torch.no_grad():
-        loss = model(image_A, image_B)
-    # model.load_state_dict(state_dict)
-    model.eval()
-    return loss
-
 if __name__ == "__main__":
     import argparse
diff --git a/training/train_multi.py b/training/train_multi.py
new file mode 100644
index 0000000..818e8e5
--- /dev/null
+++ b/training/train_multi.py
@@ -0,0 +1,393 @@
+import os
+import random
+from datetime import datetime
+
+from tqdm import tqdm
+import torch
+import torch.nn.functional as F
+import dataset_multi
+from torch.utils.data import ConcatDataset, DataLoader
+from icon_registration.losses import ICONLoss, to_floats
+from unigradicon import make_network
+import math
+
+def write_stats(writer, stats: ICONLoss, ite, prefix=""):
+    for k, v in to_floats(stats)._asdict().items():
+        writer.add_scalar(f"{prefix}{k}", v, ite)
+
+input_shape = [1, 1, 175, 175, 175]
+DATA_NUM = 4000
+BATCH_SIZE = 4
+device_ids = [1, 0, 2, 3]
+GPUS = len(device_ids)
+EXP_DIR = "./results/multigradicon/"
+
+def get_multi_training_set():
+    randomization = 'random'
+
+    datasets = [
+        dataset_multi.COPDDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, ROI_only=False),
+        dataset_multi.BratsRegDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization),
+        dataset_multi.L2rAbdomenDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization, augmentation=True),
+        dataset_multi.HCPDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization),
+        dataset_multi.ABCDFAMDDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization),
+        dataset_multi.OAIMMDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True),
+        dataset_multi.L2rMRCTDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, augmentation=True),
+        dataset_multi.UKBiobankDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization),
+    ]
+
+    dataset = ConcatDataset(datasets)
+
+    # Modality-based weighting. Use a separate loop variable so the
+    # ConcatDataset is not shadowed, and keep the per-sample weights flat,
+    # as WeightedRandomSampler expects a 1-D sequence.
+    dataset_weights = []
+    for ds in datasets:
+        weight = [ds.region_num * math.comb(len(ds.modalities) + 1, 2)] * len(ds)
+        dataset_weights.extend(weight)
+
+    return dataset, dataset_weights
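+
+# Worked example of the modality-pair weighting above (illustrative only):
+# with m modalities there are math.comb(m + 1, 2) unordered sampling pairs
+# including same-modality pairs, so a 6-region, 2-modality dataset gets
+# per-sample weight 6 * 3 = 18, while a 1-region, 1-modality dataset gets 1.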
+
+def get_multi_finetuning_set():
+    randomization = 'random'
+
+    datasets = [
+        dataset_multi.PancreasDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True),
+        dataset_multi.L2rThoraxCBCTDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True),
+        dataset_multi.ABCDDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True),
+        dataset_multi.COPDDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, ROI_only=True),
+        dataset_multi.BratsRegDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization),
+        dataset_multi.L2rAbdomenDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization, augmentation=False),
+        dataset_multi.HCPDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization),
+        dataset_multi.OAIMMDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True),
+        dataset_multi.L2rMRCTDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, augmentation=False),
+        dataset_multi.UKBiobankDataset(desired_shape=input_shape[2:], device=device_ids[0], data_num=DATA_NUM, return_labels=True, randomization=randomization),
+    ]
+
+    dataset = ConcatDataset(datasets)
+
+    # Anatomic-region-based weighting: count how many datasets cover each
+    # anatomy, then upweight datasets that cover rare anatomies. Again, do
+    # not shadow `dataset`, and keep the weights flat.
+    anatomies_count = {}
+    for ds in datasets:
+        for anatomy in ds.anatomies:
+            if anatomy not in anatomies_count:
+                anatomies_count[anatomy] = 0
+            anatomies_count[anatomy] += 1
+
+    max_anatomy = max(anatomies_count.values())
+
+    dataset_weights = []
+    for ds in datasets:
+        weight = 0
+        for anatomy in ds.anatomies:
+            weight += max_anatomy / anatomies_count[anatomy]
+        dataset_weights.extend([weight] * len(ds))
+
+    return dataset, dataset_weights
+
+def augment(image_A, image_B, label_A, label_B):
+    device = image_A.device
+    identity_list = []
+    for i in range(image_A.shape[0]):
+        identity = torch.tensor([[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], device=device)
+        idxs = set((0, 1, 2))
+        for j in range(3):
+            # Assign each output axis a distinct input axis: a random
+            # permutation matrix.
+            k = random.choice(list(idxs))
+            idxs.remove(k)
+            identity[0, j, k] = 1
+        # Random per-axis sign flips.
+        identity = identity * (torch.randint_like(identity, 0, 2, device=device) * 2 - 1)
+        identity_list.append(identity)
+
+    identity = torch.cat(identity_list)
+
+    noise = torch.randn((image_A.shape[0], 3, 4), device=device)
+
+    forward = identity + .05 * noise
+
+    grid_shape = list(image_A.shape)
+    grid_shape[1] = 3
+    forward_grid = F.affine_grid(forward, grid_shape)
+
+    if image_A.shape[1] > 1:
+        # Then we have segmentations
+        warped_A = F.grid_sample(image_A[:, :1], forward_grid, padding_mode='border')
+        warped_A_seg = F.grid_sample(image_A[:, 1:], forward_grid, mode='nearest', padding_mode='border')
+        warped_A = torch.cat([warped_A, warped_A_seg], axis=1)
+    else:
+        warped_A = F.grid_sample(image_A, forward_grid, padding_mode='border')
+    # Warp the label with the same grid as its image, in either branch.
+    warped_label_A = F.grid_sample(label_A, forward_grid, padding_mode='border')
+
+    noise = torch.randn((image_A.shape[0], 3, 4), device=device)
+    forward = identity + .05 * noise
+
+    grid_shape = list(image_A.shape)
+    grid_shape[1] = 3
+    forward_grid = F.affine_grid(forward, grid_shape)
+
+    if image_B.shape[1] > 1:
+        # Then we have segmentations
+        warped_B = F.grid_sample(image_B[:, :1], forward_grid, padding_mode='border')
+        warped_B_seg = F.grid_sample(image_B[:, 1:], forward_grid, mode='nearest', padding_mode='border')
+        warped_B = torch.cat([warped_B, warped_B_seg], axis=1)
+    else:
+        warped_B = F.grid_sample(image_B, forward_grid, padding_mode='border')
+    warped_label_B = F.grid_sample(label_B, forward_grid, padding_mode='border')
+
+    return warped_A, warped_B, warped_label_A, warped_label_B
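+
+# The `identity` built in augment() is a random signed permutation matrix
+# (random axis swaps and flips); adding 0.05 * noise perturbs it into a small
+# random affine. One possible draw (illustrative only):
+#
+#   [[ 0, 1, 0, 0],
+#    [-1, 0, 0, 0],     # swap x/y, flip the new y, keep z;
+#    [ 0, 0, 1, 0]]     # the last column is the translation part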
+
+def train_kernel(optimizer, net, moving_image, fixed_image, moving_label, fixed_label, writer, ite):
+    optimizer.zero_grad()
+    loss_object = net(moving_image, fixed_image, moving_label, fixed_label)
+    loss = torch.mean(loss_object.all_loss)
+    loss.backward()
+    optimizer.step()
+    write_stats(writer, loss_object, ite, prefix="train/")
+
+def train(
+    net,
+    optimizer,
+    data_loader,
+    val_data_loader,
+    epochs=200,
+    eval_period=-1,
+    save_period=-1,
+    step_callback=(lambda net: None),
+    unwrapped_net=None,
+    data_augmenter=None,
+):
+    """A training function intended for long-running experiments, with
+    tensorboard logging and model checkpoints. Use for medical registration
+    training.
+    """
+    import footsteps
+    from torch.utils.tensorboard import SummaryWriter
+
+    if unwrapped_net is None:
+        unwrapped_net = net
+
+    writer = SummaryWriter(
+        footsteps.output_dir + "/logs/" + datetime.now().strftime("%Y%m%d-%H%M%S"),
+        flush_secs=30,
+    )
+
+    iteration = 0
+    for epoch in tqdm(range(epochs)):
+        for moving_image, fixed_image, moving_label, fixed_label in data_loader:
+            moving_image, fixed_image, moving_label, fixed_label = moving_image.cuda(), fixed_image.cuda(), moving_label.cuda(), fixed_label.cuda()
+            if data_augmenter is not None:
+                with torch.no_grad():
+                    # Unpack the warped labels into the label variables (not
+                    # the image variables) so they stay paired with their images.
+                    moving_image, fixed_image, moving_label, fixed_label = data_augmenter(moving_image, fixed_image, moving_label, fixed_label)
+            train_kernel(optimizer, net, moving_image, fixed_image, moving_label, fixed_label, writer, iteration)
+            iteration += 1
+
+        step_callback(unwrapped_net)
+
+        if epoch % save_period == 0:
+            torch.save(
+                optimizer.state_dict(),
+                footsteps.output_dir + "checkpoints/optimizer_weights_" + str(epoch),
+            )
+            torch.save(
+                unwrapped_net.regis_net.state_dict(),
+                footsteps.output_dir + "checkpoints/network_weights_" + str(epoch),
+            )
+
+        if epoch % eval_period == 0:
+            visualization_moving, visualization_fixed, _, _ = next(iter(val_data_loader))
+            visualization_moving, visualization_fixed = visualization_moving[:, :1].cuda(), visualization_fixed[:, :1].cuda()
+            unwrapped_net.eval()
+            warped = []
+            with torch.no_grad():
+                eval_loss = unwrapped_net(visualization_moving, visualization_fixed)
+                write_stats(writer, eval_loss, epoch, prefix="val/")
+                warped = unwrapped_net.warped_image_A.cpu()
+                del eval_loss
+                unwrapped_net.clean()
+            unwrapped_net.train()
+
+            def render(im):
+                if len(im.shape) == 5:
+                    im = im[:, :, :, im.shape[3] // 2]
+                if torch.min(im) < 0:
+                    im = im - torch.min(im)
+                if torch.max(im) > 1:
+                    im = im / torch.max(im)
+                return im[:4, [0, 0, 0]].detach().cpu()
+
+            writer.add_images(
+                "moving_image", render(visualization_moving[:4]), epoch, dataformats="NCHW"
+            )
+            writer.add_images(
+                "fixed_image", render(visualization_fixed[:4]), epoch, dataformats="NCHW"
+            )
+            writer.add_images(
+                "warped_moving_image",
+                render(warped),
+                epoch,
+                dataformats="NCHW",
+            )
+            writer.add_images(
+                "difference",
+                render(torch.clip((warped[:4, :1] - visualization_fixed[:4, :1].cpu()) + 0.5, 0, 1)),
+                epoch,
+                dataformats="NCHW",
+            )
+
+    torch.save(
+        optimizer.state_dict(),
+        footsteps.output_dir + "checkpoints/optimizer_weights_" + str(epoch),
+    )
+    torch.save(
+        unwrapped_net.regis_net.state_dict(),
+        footsteps.output_dir + "checkpoints/network_weights_" + str(epoch),
+    )
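+
+# Hedged usage sketch for train() (illustrative only; the loader must yield
+# (moving, fixed, moving_label, fixed_label) batches):
+#
+#   net = make_network(input_shape, include_last_step=False, use_label=True)
+#   net_par = torch.nn.DataParallel(net, device_ids=device_ids).cuda()
+#   opt = torch.optim.Adam(net_par.parameters(), lr=0.00005)
+#   train(net_par, opt, loader, val_loader, epochs=10, eval_period=5,
+#         save_period=5, unwrapped_net=net, data_augmenter=augment)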
+
+def train_two_stage(input_shape, data_loader, val_data_loader, GPUS, epochs, eval_period, save_period, resume_from):
+    net = make_network(input_shape, include_last_step=False, use_label=True)
+
+    torch.cuda.set_device(device_ids[0])
+    torch.backends.cudnn.enabled = True
+    torch.backends.cudnn.benchmark = True
+
+    # Continue training from a checkpoint if requested.
+    if resume_from != "":
+        print("Resume from: ", resume_from)
+        net.regis_net.load_state_dict(torch.load(resume_from, map_location="cpu"))
+
+    if GPUS == 1:
+        net_par = net.cuda()
+    else:
+        net_par = torch.nn.DataParallel(net, device_ids=device_ids, output_device=device_ids[0]).cuda()
+    optimizer = torch.optim.Adam(net_par.parameters(), lr=0.00005)
+
+    if resume_from != "":
+        optimizer.load_state_dict(torch.load(resume_from.replace("network_weights_", "optimizer_weights_"), map_location="cpu"))
+
+    net_par.train()
+
+    print("start train.")
+    train(net_par, optimizer, data_loader, val_data_loader, unwrapped_net=net,
+          epochs=epochs[0], eval_period=eval_period, save_period=save_period, data_augmenter=augment)
+
+    torch.save(
+        net.regis_net.state_dict(),
+        footsteps.output_dir + "checkpoints/Step_1_final.trch",
+    )
+
+    # Stage 2 adds the final registration step; like stage 1, it is trained
+    # with labels, so use_label must be set here as well.
+    net_2 = make_network(input_shape, include_last_step=True, use_label=True)
+
+    # Initialize the first two steps of stage 2 from the stage-1 weights.
+    net_2.regis_net.netPhi.load_state_dict(net.regis_net.state_dict())
+
+    del net
+    del net_par
+    del optimizer
+
+    if GPUS == 1:
+        net_2_par = net_2.cuda()
+    else:
+        net_2_par = torch.nn.DataParallel(net_2, device_ids=device_ids, output_device=device_ids[0]).cuda()
+    optimizer = torch.optim.Adam(net_2_par.parameters(), lr=0.00005)
+
+    net_2_par.train()
+
+    # We're training two networks in one script. This hack keeps the second
+    # training from overwriting the outputs of the first.
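+    # Reassigning footsteps.output_dir_impl redirects subsequent checkpoint
+    # and log paths to a fresh subtree, e.g. <run>/2nd_step/checkpoints/
+    # instead of <run>/checkpoints/ (paths illustrative).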
+    footsteps.output_dir_impl = footsteps.output_dir + "2nd_step/"
+    os.makedirs(footsteps.output_dir + "checkpoints", exist_ok=True)
+
+    train(net_2_par, optimizer, data_loader, val_data_loader, unwrapped_net=net_2, epochs=epochs[1], eval_period=eval_period, save_period=save_period, data_augmenter=augment)
+
+    torch.save(
+        net_2.regis_net.state_dict(),
+        footsteps.output_dir + "checkpoints/Step_2_final.trch",
+    )
+
+    return net_2
+
+def finetune(net, data_loader, val_data_loader, GPUS, epochs, eval_period, save_period):
+    if GPUS == 1:
+        net_par = net.cuda()
+    else:
+        net_par = torch.nn.DataParallel(net, device_ids=device_ids, output_device=device_ids[0]).cuda()
+    optimizer = torch.optim.Adam(net_par.parameters(), lr=0.00005)
+
+    net_par.train()
+
+    footsteps.output_dir_impl = footsteps.output_dir + "finetune/"
+    os.makedirs(footsteps.output_dir + "checkpoints", exist_ok=True)
+
+    train(net_par, optimizer, data_loader, val_data_loader, unwrapped_net=net, epochs=epochs, eval_period=eval_period, save_period=save_period, data_augmenter=augment)
+
+    torch.save(
+        net.regis_net.state_dict(),
+        footsteps.output_dir + "checkpoints/Step_2_final.trch",
+    )
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--resume_from", required=False, default="")
+    args = parser.parse_args()
+    resume_from = args.resume_from
+
+    import footsteps
+    footsteps.initialize(output_root=EXP_DIR)
+
+    dataset, weights = get_multi_training_set()
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE * GPUS,
+        num_workers=4,
+        drop_last=True,
+        sampler=torch.utils.data.WeightedRandomSampler(weights, DATA_NUM),
+    )
+
+    val_dataloader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=True,
+        num_workers=4,
+        drop_last=True,
+    )
+    print("Finish data loading...")
+
+    os.makedirs(footsteps.output_dir + "checkpoints", exist_ok=True)
+
+    print("Start training...")
+    net = train_two_stage(input_shape, dataloader, val_dataloader, GPUS, [801, 201], 20, 20, resume_from)
+
+    del dataloader, val_dataloader, dataset, weights
+
+    print("Start finetuning...")
+
+    fine_dataset, fine_weights = get_multi_finetuning_set()
+
+    fine_dataloader = DataLoader(
+        fine_dataset,
+        batch_size=BATCH_SIZE * GPUS,
+        num_workers=4,
+        drop_last=True,
+        sampler=torch.utils.data.WeightedRandomSampler(fine_weights, DATA_NUM),
+    )
+
+    fine_val_dataloader = DataLoader(
+        fine_dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=True,
+        num_workers=4,
+        drop_last=True,
+    )
+
+    finetune(net, fine_dataloader, fine_val_dataloader, GPUS, 100, 20, 20)
\ No newline at end of file

From f2e870f82c0ab29d10ef7fa74d3169bcb5ad6b8d Mon Sep 17 00:00:00 2001
From: Basar Demir
Date: Tue, 29 Oct 2024 10:22:57 -0400
Subject: [PATCH 4/4] Fix folder hierarchy for footsteps

---
 training/train_multi.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/training/train_multi.py b/training/train_multi.py
index 818e8e5..e9ae831 100644
--- a/training/train_multi.py
+++ b/training/train_multi.py
@@ -323,14 +323,14 @@ def finetune(net, data_loader, val_data_loader, GPUS, epochs, eval_period, save_
     net_par.train()
 
-    footsteps.output_dir_impl = footsteps.output_dir + "finetune/"
+    footsteps.output_dir_impl = footsteps.output_dir.split("2nd_step/")[0] + "finetune/"
     os.makedirs(footsteps.output_dir + "checkpoints", exist_ok=True)
 
     train(net_par, optimizer, data_loader, val_data_loader, unwrapped_net=net, epochs=epochs, eval_period=eval_period, save_period=save_period, data_augmenter=augment)
 
     torch.save(
         net.regis_net.state_dict(),
-        footsteps.output_dir + "checkpoints/Step_2_final.trch",
+        footsteps.output_dir + "checkpoints/finetune_final.trch",
     )
 
 if __name__ == "__main__":
@@ -390,4 +390,7 @@ def finetune(net, data_loader, val_data_loader, GPUS, epochs, eval_period, save_
         drop_last=True,
     )
 
+    print("Finish data loading...")
+
+    print("Start finetuning...")
     finetune(net, fine_dataloader, fine_val_dataloader, GPUS, 100, 20, 20)
\ No newline at end of file