Skip to content

Commit

Permalink
Enable user to load multigradicon.
Browse files Browse the repository at this point in the history
  • Loading branch information
lintian-a committed Jul 31, 2024
1 parent 6ed8200 commit 067d3ff
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 6 deletions.
32 changes: 30 additions & 2 deletions src/unigradicon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,30 @@ def make_sim(similarity):
else:
raise ValueError(f"Similarity measure {similarity} not recognized. Choose from [lncc, lncc2, mind].")

def get_multigradicon(loss_fn=icon.LNCC(sigma=5)):
net = make_network(input_shape, include_last_step=True, loss_fn=loss_fn)
from os.path import exists
weights_location = "network_weights/multigradicon1.0/Step_2_final.trch"
if not exists(weights_location):
print("Downloading pretrained multigradicon model")
import urllib.request
import os
download_path = "https://github.com/uncbiag/uniGradICON/releases/download/multigradicon_weights/Step_2_final.trch"
os.makedirs("network_weights/multigradicon1.0/", exist_ok=True)
urllib.request.urlretrieve(download_path, weights_location)
print(f"Loading weights from {weights_location}")
trained_weights = torch.load(weights_location, map_location=torch.device("cpu"))
net.regis_net.load_state_dict(trained_weights)
net.to(config.device)
net.eval()
return net

def get_unigradicon(loss_fn=icon.LNCC(sigma=5)):
net = make_network(input_shape, include_last_step=True, loss_fn=loss_fn)
from os.path import exists
weights_location = "network_weights/unigradicon1.0/Step_2_final.trch"
if not exists(weights_location):
print("Downloading pretrained model")
print("Downloading pretrained unigradicon model")
import urllib.request
import os
download_path = "https://github.com/uncbiag/uniGradICON/releases/download/unigradicon_weights/Step_2_final.trch"
Expand All @@ -186,6 +204,14 @@ def get_unigradicon(loss_fn=icon.LNCC(sigma=5)):
net.eval()
return net

def get_model_from_model_zoo(model_name="unigradicon", loss_fn=icon.LNCC(sigma=5)):
if model_name == "unigradicon":
return get_unigradicon(loss_fn)
elif model_name == "multigradicon":
return get_multigradicon(loss_fn)
else:
raise ValueError(f"Model {model_name} not recognized. Choose from [unigradicon, multigradicon].")

def quantile(arr: torch.Tensor, q):
arr = arr.flatten()
l = len(arr)
Expand Down Expand Up @@ -252,10 +278,12 @@ def main():
default="50", help="The number of IO iterations. Default is 50. Set to 'None' to disable IO.")
parser.add_argument("--io_sim", required=False,
default="lncc", help="The similarity measure used in IO. Default is LNCC. Choose from [lncc, lncc2, mind].")
parser.add_argument("--model", required=False,
default="unigradicon", help="The model to load. Default is unigradicon. Choose from [unigradicon, multigradicon].")

args = parser.parse_args()

net = get_unigradicon(make_sim(args.io_sim))
net = get_model_from_model_zoo(args.model, make_sim(args.io_sim))

fixed = itk.imread(args.fixed)
moving = itk.imread(args.moving)
Expand Down
14 changes: 10 additions & 4 deletions tests/test_command_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import subprocess
import os
import torch


class TestCommandInterface(unittest.TestCase):
Expand All @@ -14,6 +15,7 @@ def __init__(self, methodName: str = "runTest") -> None:
self.test_data_dir = icon_registration.test_utils.TEST_DATA_DIR
self.test_temp_dir = f"{self.test_data_dir}/temp"
os.makedirs(self.test_temp_dir, exist_ok=True)
self.device = torch.cuda.current_device()

def test_register_unigradicon_inference(self):
subprocess.run([
Expand Down Expand Up @@ -58,8 +60,8 @@ def test_register_unigradicon_inference(self):

# remove temp file
os.remove(f"{self.test_temp_dir}/transform.hdf5")

def test_register_unigradicon_io(self):
def test_register_multigradicon_inference(self):
subprocess.run([
"unigradicon-register",
"--fixed", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_img.nii.gz",
Expand All @@ -68,7 +70,9 @@ def test_register_unigradicon_io(self):
"--moving", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_img.nii.gz",
"--moving_modality", "ct",
"--moving_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_label.nii.gz",
"--transform_out", f"{self.test_temp_dir}/transform.hdf5"
"--transform_out", f"{self.test_temp_dir}/transform.hdf5",
"--io_iterations", "None",
"--model", "multigradicon"
])

# load transform
Expand Down Expand Up @@ -97,8 +101,10 @@ def test_register_unigradicon_io(self):
)
dists.append(np.sqrt(np.sum((px - py) ** 2)))
print(np.mean(dists))
self.assertLess(np.mean(dists), 1.5)
self.assertLess(np.mean(dists), 3.8)

# remove temp file
os.remove(f"{self.test_temp_dir}/transform.hdf5")



0 comments on commit 067d3ff

Please sign in to comment.