Merge branch 'main' into test-unigradicon-warp-issue
HastingsGreer authored Nov 20, 2024
2 parents 1698789 + 8a11d35 commit bcd01db
Showing 11 changed files with 2,050 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/gpu-test-action.yml
@@ -3,7 +3,7 @@ name: gpu-tests
on:
pull_request:
push:
branches: main
branches: [dev, main]

jobs:
test-linux:
49 changes: 40 additions & 9 deletions README.md
@@ -9,15 +9,29 @@ The result is a deep-learning-based registration model that works well across da

![teaser](IntroFigure.jpg?raw=true)

**uniGradICON: A Foundation Model for Medical Image Registration**
Tian, Lin and Greer, Hastings and Kwitt, Roland and Vialard, Francois-Xavier and Estepar, Raul San Jose and Bouix, Sylvain and Rushmore, Richard and Niethammer, Marc
_MICCAI 2024_ https://arxiv.org/abs/2403.05780

**multiGradICON: A Foundation Model for Multimodal Medical Image Registration**
Demir, Basar and Tian, Lin and Greer, Thomas Hastings and Kwitt, Roland and Vialard, Francois-Xavier and Estepar, Raul San Jose and Bouix, Sylvain and Rushmore, Richard Jarrett and Ebrahim, Ebrahim and Niethammer, Marc
_MICCAI Workshop on Biomedical Image Registration (WBIR) 2024_ https://arxiv.org/abs/2408.00221

Please (currently) cite as:
```
@misc{tian2024unigradicon,
title={uniGradICON: A Foundation Model for Medical Image Registration},
author={Lin Tian and Hastings Greer and Roland Kwitt and Francois-Xavier Vialard and Raul San Jose Estepar and Sylvain Bouix and Richard Rushmore and Marc Niethammer},
year={2024},
eprint={2403.05780},
archivePrefix={arXiv},
primaryClass={cs.CV}
@article{tian2024unigradicon,
title={uniGradICON: A Foundation Model for Medical Image Registration},
author={Tian, Lin and Greer, Hastings and Kwitt, Roland and Vialard, Francois-Xavier and Estepar, Raul San Jose and Bouix, Sylvain and Rushmore, Richard and Niethammer, Marc},
journal={arXiv preprint arXiv:2403.05780},
year={2024}
}
```
```
@article{demir2024multigradicon,
title={multiGradICON: A Foundation Model for Multimodal Medical Image Registration},
author={Demir, Basar and Tian, Lin and Greer, Thomas Hastings and Kwitt, Roland and Vialard, Francois-Xavier and Estepar, Raul San Jose and Bouix, Sylvain and Rushmore, Richard Jarrett and Ebrahim, Ebrahim and Niethammer, Marc},
journal={arXiv preprint arXiv:2408.00221},
year={2024}
}
```

@@ -204,12 +218,25 @@ unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=Reg
```

To register without instance optimization
To register without instance optimization (IO)
```
unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=RegLib_C01_1.nrrd --moving_modality=mri --transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd --io_iterations None
```

To warp
To use a different similarity measure for the IO. We currently support three similarity measures:
- LNCC: lncc
- Squared LNCC: lncc2
- MIND SSC: mind
```
unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=RegLib_C01_1.nrrd --moving_modality=mri --transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd --io_iterations 50 --io_sim lncc2
```
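
These names are resolved to `icon_registration` similarity objects by the `make_sim` helper added to `src/unigradicon/__init__.py` in this commit. A minimal sketch of the programmatic equivalent, assuming the package from this commit is installed:
```
from unigradicon import get_unigradicon, make_sim

# "lncc2" resolves to icon.SquaredLNCC(sigma=5), the same object the CLI builds for --io_sim.
sim = make_sim("lncc2")
net = get_unigradicon(loss_fn=sim)  # pretrained weights are downloaded on first use
```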

To load a specific model for inference. We currently support uniGradICON and multiGradICON.
```
unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=RegLib_C01_1.nrrd --moving_modality=mri --transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd --model multigradicon
```
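
Programmatically, the `--model` choice corresponds to the `get_model_from_model_zoo` helper added in this commit. A minimal sketch, assuming the package from this commit is installed (weights are downloaded on first use):
```
from unigradicon import get_model_from_model_zoo, make_sim

# Roughly equivalent to "--model multigradicon --io_sim lncc" at model-loading time.
net = get_model_from_model_zoo("multigradicon", make_sim("lncc"))
```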

To warp an image
```
unigradicon-warp --fixed [fixed_image_file_name] --moving [moving_image_file_name] --transform trans.hdf5 --warped_moving_out warped.nii.gz --linear
```
@@ -218,8 +245,12 @@ To warp a label map
```
unigradicon-warp --fixed [fixed_image_file_name] --moving [moving_image_segmentation_file_name] --transform trans.hdf5 --warped_moving_out warped_seg.nii.gz --nearest_neighbor
```
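
The transform written by `unigradicon-register` can also be read back with ITK's Python bindings, which is how the new tests in `tests/test_command_arguments.py` evaluate landmark error. A minimal sketch (the landmark coordinates are placeholders):
```
import itk
import numpy as np

# trans.hdf5 holds an itk.CompositeTransform written by unigradicon-register.
phi_AB = itk.transformread("trans.hdf5")[0]

# Map a fixed-image landmark (physical coordinates) into the moving image's space,
# as the COPD landmark tests in this commit do.
landmark = (10.0, 20.0, 30.0)  # placeholder point
warped = np.array(phi_AB.TransformPoint(landmark))
print(warped)
```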

We also provide a [colab](https://colab.research.google.com/drive/1JuFL113WN3FHCoXG-4fiBTWIyYpwGyGy?usp=sharing) demo.

## Slicer Extension

A Slicer extension is available [here](https://github.com/uncbiag/SlicerUniGradICON?tab=readme-ov-file) (and will hopefully soon be available via the Slicer Extension Manager).

## Plays well with others

1 change: 1 addition & 0 deletions requirements.txt
@@ -0,0 +1 @@
icon_registration>=1.1.5
6 changes: 3 additions & 3 deletions setup.cfg
@@ -1,8 +1,8 @@
[metadata]
name = unigradicon
version = 1.0.2
version = 1.0.3
author = Lin Tian
author_email =
author_email = [email protected]
description = a foundation model for medical image registration
long_description = file: README.md
long_description_content_type = text/markdown
@@ -21,7 +21,7 @@ packages = find:
python_requires = >=3.7

install_requires =
icon_registration>=1.1.4
icon_registration>=1.1.5

[options.packages.find]
where = src
97 changes: 79 additions & 18 deletions src/unigradicon/__init__.py
@@ -15,23 +15,27 @@
from icon_registration.mermaidlite import compute_warped_image_multiNC
import icon_registration.itk_wrapper



input_shape = [1, 1, 175, 175, 175]

class GradientICONSparse(network_wrappers.RegistrationModule):
def __init__(self, network, similarity, lmbda):
def __init__(self, network, similarity, lmbda, use_label=False):

super().__init__()

self.regis_net = network
self.lmbda = lmbda
self.similarity = similarity
self.use_label = use_label

def forward(self, image_A, image_B):
def forward(self, image_A, image_B, label_A=None, label_B=None):

assert self.identity_map.shape[2:] == image_A.shape[2:]
assert self.identity_map.shape[2:] == image_B.shape[2:]
if self.use_label:
label_A = image_A if label_A is None else label_A
label_B = image_B if label_B is None else label_B
assert self.identity_map.shape[2:] == label_A.shape[2:]
assert self.identity_map.shape[2:] == label_B.shape[2:]

# Tag used elsewhere for optimization.
# Must be set at beginning of forward b/c not preserved by .cuda() etc
@@ -75,10 +79,29 @@ def forward(self, image_A, image_B):
1,
zero_boundary=True
)

similarity_loss = self.similarity(
self.warped_image_A, image_B
) + self.similarity(self.warped_image_B, image_A)

if self.use_label:
self.warped_label_A = compute_warped_image_multiNC(
torch.cat([label_A, inbounds_tag], axis=1) if inbounds_tag is not None else label_A,
self.phi_AB_vectorfield,
self.spacing,
1,
)

self.warped_label_B = compute_warped_image_multiNC(
torch.cat([label_B, inbounds_tag], axis=1) if inbounds_tag is not None else label_B,
self.phi_BA_vectorfield,
self.spacing,
1,
)

similarity_loss = self.similarity(
self.warped_label_A, label_B
) + self.similarity(self.warped_label_B, label_A)
else:
similarity_loss = self.similarity(
self.warped_image_A, image_B
) + self.similarity(self.warped_image_B, image_A)

if len(self.input_shape) - 2 == 3:
Iepsilon = (
@@ -142,8 +165,10 @@ def forward(self, image_A, image_B):

def clean(self):
del self.phi_AB, self.phi_BA, self.phi_AB_vectorfield, self.phi_BA_vectorfield, self.warped_image_A, self.warped_image_B
if self.use_label:
del self.warped_label_A, self.warped_label_B

def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.LNCC(sigma=5)):
def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.LNCC(sigma=5), use_label=False):
dimension = len(input_shape) - 2
inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=dimension))

@@ -155,17 +180,44 @@ def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.L
if include_last_step:
inner_net = icon.TwoStepRegistration(inner_net, icon.FunctionFromVectorField(networks.tallUNet2(dimension=dimension)))

net = GradientICONSparse(inner_net, loss_fn, lmbda=lmbda)
net = GradientICONSparse(inner_net, loss_fn, lmbda=lmbda, use_label=use_label)
net.assign_identity_map(input_shape)
return net

def make_sim(similarity):
if similarity == "lncc":
return icon.LNCC(sigma=5)
elif similarity == "lncc2":
return icon.SquaredLNCC(sigma=5)
elif similarity == "mind":
return icon.MINDSSC(radius=2, dilation=2)
else:
raise ValueError(f"Similarity measure {similarity} not recognized. Choose from [lncc, lncc2, mind].")

def get_multigradicon(loss_fn=icon.LNCC(sigma=5)):
net = make_network(input_shape, include_last_step=True, loss_fn=loss_fn)
from os.path import exists
weights_location = "network_weights/multigradicon1.0/Step_2_final.trch"
if not exists(weights_location):
print("Downloading pretrained multigradicon model")
import urllib.request
import os
download_path = "https://github.com/uncbiag/uniGradICON/releases/download/multigradicon_weights/Step_2_final.trch"
os.makedirs("network_weights/multigradicon1.0/", exist_ok=True)
urllib.request.urlretrieve(download_path, weights_location)
print(f"Loading weights from {weights_location}")
trained_weights = torch.load(weights_location, map_location=torch.device("cpu"))
net.regis_net.load_state_dict(trained_weights)
net.to(config.device)
net.eval()
return net

def get_unigradicon():
net = make_network(input_shape, include_last_step=True)
def get_unigradicon(loss_fn=icon.LNCC(sigma=5)):
net = make_network(input_shape, include_last_step=True, loss_fn=loss_fn)
from os.path import exists
weights_location = "network_weights/unigradicon1.0/Step_2_final.trch"
if not exists(weights_location):
print("Downloading pretrained model")
print("Downloading pretrained unigradicon model")
import urllib.request
import os
download_path = "https://github.com/uncbiag/uniGradICON/releases/download/unigradicon_weights/Step_2_final.trch"
@@ -177,6 +229,14 @@ def get_unigradicon():
net.eval()
return net

def get_model_from_model_zoo(model_name="unigradicon", loss_fn=icon.LNCC(sigma=5)):
if model_name == "unigradicon":
return get_unigradicon(loss_fn)
elif model_name == "multigradicon":
return get_multigradicon(loss_fn)
else:
raise ValueError(f"Model {model_name} not recognized. Choose from [unigradicon, multigradicon].")

def quantile(arr: torch.Tensor, q):
arr = arr.flatten()
l = len(arr)
@@ -202,7 +262,7 @@ def preprocess(image, modality="ct", segmentation=None):
min_ = -1000
max_ = 1000
image = itk.CastImageFilter[type(image), itk.Image[itk.F, 3]].New()(image)
image = itk.clamp_image_filter(image, Bounds=(-1000, 1000))
image = itk.clamp_image_filter(image, Bounds=(min_, max_))
elif modality == "mri":
image = itk.CastImageFilter[type(image), itk.Image[itk.F, 3]].New()(image)
min_, _ = itk.image_intensity_min_max(image)
@@ -241,10 +301,14 @@ def main():
default=None, type=str, help="The path to save the warped image.")
parser.add_argument("--io_iterations", required=False,
default="50", help="The number of IO iterations. Default is 50. Set to 'None' to disable IO.")
parser.add_argument("--io_sim", required=False,
default="lncc", help="The similarity measure used in IO. Default is LNCC. Choose from [lncc, lncc2, mind].")
parser.add_argument("--model", required=False,
default="unigradicon", help="The model to load. Default is unigradicon. Choose from [unigradicon, multigradicon].")

args = parser.parse_args()

net = get_unigradicon()
net = get_model_from_model_zoo(args.model, make_sim(args.io_sim))

fixed = itk.imread(args.fixed)
moving = itk.imread(args.moving)
@@ -345,6 +409,3 @@ def maybe_cast(img: itk.Image):

return img, maybe_cast_back




110 changes: 110 additions & 0 deletions tests/test_command_arguments.py
@@ -0,0 +1,110 @@
import itk
import numpy as np
import unittest
import icon_registration.test_utils

import subprocess
import os
import torch


class TestCommandInterface(unittest.TestCase):
def __init__(self, methodName: str = "runTest") -> None:
super().__init__(methodName)
icon_registration.test_utils.download_test_data()
self.test_data_dir = icon_registration.test_utils.TEST_DATA_DIR
self.test_temp_dir = f"{self.test_data_dir}/temp"
os.makedirs(self.test_temp_dir, exist_ok=True)
self.device = torch.cuda.current_device()

def test_register_unigradicon_inference(self):
subprocess.run([
"unigradicon-register",
"--fixed", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_img.nii.gz",
"--fixed_modality", "ct",
"--fixed_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_label.nii.gz",
"--moving", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_img.nii.gz",
"--moving_modality", "ct",
"--moving_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_label.nii.gz",
"--transform_out", f"{self.test_temp_dir}/transform.hdf5",
"--io_iterations", "None"
])

# load transform
phi_AB = itk.transformread(f"{self.test_temp_dir}/transform.hdf5")[0]

assert isinstance(phi_AB, itk.CompositeTransform)

insp_points = icon_registration.test_utils.read_copd_pointset(
str(
icon_registration.test_utils.TEST_DATA_DIR
/ "lung_test_data/copd1_300_iBH_xyz_r1.txt"
)
)
exp_points = icon_registration.test_utils.read_copd_pointset(
str(
icon_registration.test_utils.TEST_DATA_DIR
/ "lung_test_data/copd1_300_eBH_xyz_r1.txt"
)
)

dists = []
for i in range(len(insp_points)):
px, py = (
insp_points[i],
np.array(phi_AB.TransformPoint(tuple(exp_points[i]))),
)
dists.append(np.sqrt(np.sum((px - py) ** 2)))
print(np.mean(dists))
self.assertLess(np.mean(dists), 2.1)

# remove temp file
os.remove(f"{self.test_temp_dir}/transform.hdf5")

def test_register_multigradicon_inference(self):
subprocess.run([
"unigradicon-register",
"--fixed", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_img.nii.gz",
"--fixed_modality", "ct",
"--fixed_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_label.nii.gz",
"--moving", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_img.nii.gz",
"--moving_modality", "ct",
"--moving_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_label.nii.gz",
"--transform_out", f"{self.test_temp_dir}/transform.hdf5",
"--io_iterations", "None",
"--model", "multigradicon"
])

# load transform
phi_AB = itk.transformread(f"{self.test_temp_dir}/transform.hdf5")[0]

assert isinstance(phi_AB, itk.CompositeTransform)

insp_points = icon_registration.test_utils.read_copd_pointset(
str(
icon_registration.test_utils.TEST_DATA_DIR
/ "lung_test_data/copd1_300_iBH_xyz_r1.txt"
)
)
exp_points = icon_registration.test_utils.read_copd_pointset(
str(
icon_registration.test_utils.TEST_DATA_DIR
/ "lung_test_data/copd1_300_eBH_xyz_r1.txt"
)
)

dists = []
for i in range(len(insp_points)):
px, py = (
insp_points[i],
np.array(phi_AB.TransformPoint(tuple(exp_points[i]))),
)
dists.append(np.sqrt(np.sum((px - py) ** 2)))
print(np.mean(dists))
self.assertLess(np.mean(dists), 3.8)

# remove temp file
os.remove(f"{self.test_temp_dir}/transform.hdf5")


