From 067d3ffb656dbeb010b61765a22c5911353db2dc Mon Sep 17 00:00:00 2001 From: Lin Tian Date: Wed, 31 Jul 2024 10:18:12 -0400 Subject: [PATCH] Enable user to load multigradicon. --- src/unigradicon/__init__.py | 32 ++++++++++++++++++++++++++++++-- tests/test_command_arguments.py | 14 ++++++++++---- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/unigradicon/__init__.py b/src/unigradicon/__init__.py index a5b1113..d284ee5 100644 --- a/src/unigradicon/__init__.py +++ b/src/unigradicon/__init__.py @@ -169,12 +169,30 @@ def make_sim(similarity): else: raise ValueError(f"Similarity measure {similarity} not recognized. Choose from [lncc, lncc2, mind].") +def get_multigradicon(loss_fn=icon.LNCC(sigma=5)): + net = make_network(input_shape, include_last_step=True, loss_fn=loss_fn) + from os.path import exists + weights_location = "network_weights/multigradicon1.0/Step_2_final.trch" + if not exists(weights_location): + print("Downloading pretrained multigradicon model") + import urllib.request + import os + download_path = "https://github.com/uncbiag/uniGradICON/releases/download/multigradicon_weights/Step_2_final.trch" + os.makedirs("network_weights/multigradicon1.0/", exist_ok=True) + urllib.request.urlretrieve(download_path, weights_location) + print(f"Loading weights from {weights_location}") + trained_weights = torch.load(weights_location, map_location=torch.device("cpu")) + net.regis_net.load_state_dict(trained_weights) + net.to(config.device) + net.eval() + return net + def get_unigradicon(loss_fn=icon.LNCC(sigma=5)): net = make_network(input_shape, include_last_step=True, loss_fn=loss_fn) from os.path import exists weights_location = "network_weights/unigradicon1.0/Step_2_final.trch" if not exists(weights_location): - print("Downloading pretrained model") + print("Downloading pretrained unigradicon model") import urllib.request import os download_path = "https://github.com/uncbiag/uniGradICON/releases/download/unigradicon_weights/Step_2_final.trch" @@ -186,6 +204,14 @@ def get_unigradicon(loss_fn=icon.LNCC(sigma=5)): net.eval() return net +def get_model_from_model_zoo(model_name="unigradicon", loss_fn=icon.LNCC(sigma=5)): + if model_name == "unigradicon": + return get_unigradicon(loss_fn) + elif model_name == "multigradicon": + return get_multigradicon(loss_fn) + else: + raise ValueError(f"Model {model_name} not recognized. Choose from [unigradicon, multigradicon].") + def quantile(arr: torch.Tensor, q): arr = arr.flatten() l = len(arr) @@ -252,10 +278,12 @@ def main(): default="50", help="The number of IO iterations. Default is 50. Set to 'None' to disable IO.") parser.add_argument("--io_sim", required=False, default="lncc", help="The similarity measure used in IO. Default is LNCC. Choose from [lncc, lncc2, mind].") + parser.add_argument("--model", required=False, + default="unigradicon", help="The model to load. Default is unigradicon. Choose from [unigradicon, multigradicon].") args = parser.parse_args() - net = get_unigradicon(make_sim(args.io_sim)) + net = get_model_from_model_zoo(args.model, make_sim(args.io_sim)) fixed = itk.imread(args.fixed) moving = itk.imread(args.moving) diff --git a/tests/test_command_arguments.py b/tests/test_command_arguments.py index 92e7221..069839d 100644 --- a/tests/test_command_arguments.py +++ b/tests/test_command_arguments.py @@ -5,6 +5,7 @@ import subprocess import os +import torch class TestCommandInterface(unittest.TestCase): @@ -14,6 +15,7 @@ def __init__(self, methodName: str = "runTest") -> None: self.test_data_dir = icon_registration.test_utils.TEST_DATA_DIR self.test_temp_dir = f"{self.test_data_dir}/temp" os.makedirs(self.test_temp_dir, exist_ok=True) + self.device = torch.cuda.current_device() def test_register_unigradicon_inference(self): subprocess.run([ @@ -58,8 +60,8 @@ def test_register_unigradicon_inference(self): # remove temp file os.remove(f"{self.test_temp_dir}/transform.hdf5") - - def test_register_unigradicon_io(self): + + def test_register_multigradicon_inference(self): subprocess.run([ "unigradicon-register", "--fixed", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_img.nii.gz", @@ -68,7 +70,9 @@ def test_register_unigradicon_io(self): "--moving", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_img.nii.gz", "--moving_modality", "ct", "--moving_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_label.nii.gz", - "--transform_out", f"{self.test_temp_dir}/transform.hdf5" + "--transform_out", f"{self.test_temp_dir}/transform.hdf5", + "--io_iterations", "None", + "--model", "multigradicon" ]) # load transform @@ -97,8 +101,10 @@ def test_register_unigradicon_io(self): ) dists.append(np.sqrt(np.sum((px - py) ** 2))) print(np.mean(dists)) - self.assertLess(np.mean(dists), 1.5) + self.assertLess(np.mean(dists), 3.8) # remove temp file os.remove(f"{self.test_temp_dir}/transform.hdf5") + +