Skip to content

Commit

Permalink
Merge pull request #9 from uncbiag/fix-preprocessing
Browse files Browse the repository at this point in the history
Fix preprocessing
  • Loading branch information
lintian-a authored Apr 15, 2024
2 parents ed2b54e + 0a31f3b commit 42b10fe
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 19 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test_readme_works.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ jobs:
wget https://www.hgreer.com/assets/slicer_mirror/RegLib_C01_2.nrrd
- name: Test
run: |
unigradicon-register --fixed=RegLib_C01_2.nrrd --moving=RegLib_C01_1.nrrd \
unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=RegLib_C01_1.nrrd --moving_modality=mri \
--transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd --io_iterations=None
unigradicon-register --fixed=RegLib_C01_2.nrrd --moving=RegLib_C01_1.nrrd \
unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=RegLib_C01_1.nrrd --moving_modality=mri \
--transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd --io_iterations=3
unigradicon-warp --fixed=RegLib_C01_2.nrrd --moving=RegLib_C01_1.nrrd \
--transform=trans.hdf5 --warped_moving_out=warped_2_C01_1.nrrd --nearest_neighbor
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.pyc
*.egg-info
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,7 @@ pip install unigradicon
wget https://www.hgreer.com/assets/slicer_mirror/RegLib_C01_1.nrrd
wget https://www.hgreer.com/assets/slicer_mirror/RegLib_C01_2.nrrd
unigradicon-register --fixed=RegLib_C01_2.nrrd --moving=RegLib_C01_1.nrrd \
--transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd
unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=RegLib_C01_1.nrrd --moving_modality=mri --transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd
```
We also provide a [colab](https://colab.research.google.com/drive/1JuFL113WN3FHCoXG-4fiBTWIyYpwGyGy?usp=sharing) demo.
Expand Down
94 changes: 79 additions & 15 deletions src/unigradicon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,36 +177,98 @@ def get_unigradicon():
net.eval()
return net

def preprocess(image):
image = itk.CastImageFilter[type(image), itk.Image[itk.F, 3]].New()(image)
max_ = np.max(np.array(image))
image = itk.shift_scale_image_filter(image, shift=0., scale = .9 / max_)

def quantile(arr: torch.Tensor, q):
arr = arr.flatten()
l = len(arr)
return torch.kthvalue(arr, int(q * l)).values

def apply_mask(image, segmentation):
segmentation_cast_filter = itk.CastImageFilter[type(segmentation),
itk.Image.F3].New()
segmentation_cast_filter.SetInput(segmentation)
segmentation_cast_filter.Update()
segmentation = segmentation_cast_filter.GetOutput()
mask_filter = itk.MultiplyImageFilter[itk.Image.F3, itk.Image.F3,
itk.Image.F3].New()

mask_filter.SetInput1(image)
mask_filter.SetInput2(segmentation)
mask_filter.Update()

return mask_filter.GetOutput()

def preprocess(image, modality="ct", segmentation=None):
if modality == "ct":
min_ = -1000
max_ = 1000
image = itk.CastImageFilter[type(image), itk.Image[itk.F, 3]].New()(image)
image = itk.clamp_image_filter(image, Bounds=(-1000, 1000))
elif modality == "mri":
image = itk.CastImageFilter[type(image), itk.Image[itk.F, 3]].New()(image)
min_, _ = itk.image_intensity_min_max(image)
max_ = quantile(torch.tensor(np.array(image)), .99).item()
image = itk.clamp_image_filter(image, Bounds=(min_, max_))
else:
raise ValueError(f"{modality} not recognized. Use 'ct' or 'mri'.")

image = itk.shift_scale_image_filter(image, shift=-min_, scale = 1/(max_-min_))

if segmentation is not None:
image = apply_mask(image, segmentation)
return image

def main():
import itk
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--fixed", required=True)
parser.add_argument("--moving", required=True)
parser.add_argument("--transform_out", required=True)
parser.add_argument("--warped_moving_out", default=None)
parser.add_argument("--io_iterations", default="50")
parser = argparse.ArgumentParser(description="Register two images using unigradicon.")
parser.add_argument("--fixed", required=True, type=str,
help="The path of the fixed image.")
parser.add_argument("--moving", required=True, type=str,
help="The path of the fixed image.")
parser.add_argument("--fixed_modality", required=True,
type=str, help="The modality of the fixed image. Should be 'ct' or 'mri'.")
parser.add_argument("--moving_modality", required=True,
type=str, help="The modality of the moving image. Should be 'ct' or 'mri'.")
parser.add_argument("--fixed_segmentation", required=False,
type=str, help="The path of the segmentation map of the fixed image. \
This map will be applied to the fixed image before registration.")
parser.add_argument("--moving_segmentation", required=False,
type=str, help="The path of the segmentation map of the moving image. \
This map will be applied to the moving image before registration.")
parser.add_argument("--transform_out", required=True,
type=str, help="The path to save the transform.")
parser.add_argument("--warped_moving_out", required=False,
default=None, type=str, help="The path to save the warped image.")
parser.add_argument("--io_iterations", required=False,
default="50", help="The number of IO iterations. Default is 50. Set to 'None' to disable IO.")

args = parser.parse_args()

net = get_unigradicon()

fixed = itk.imread(args.fixed)
moving = itk.imread(args.moving)

if args.fixed_segmentation is not None:
fixed_segmentation = itk.imread(args.fixed_segmentation)
else:
fixed_segmentation = None

if args.moving_segmentation is not None:
moving_segmentation = itk.imread(args.moving_segmentation)
else:
moving_segmentation = None

if args.io_iterations == "None":
io_iterations = None
else:
io_iterations = int(args.io_iterations)

phi_AB, phi_BA = icon_registration.itk_wrapper.register_pair(net,preprocess(moving), preprocess(fixed), finetune_steps=io_iterations)
phi_AB, phi_BA = icon_registration.itk_wrapper.register_pair(
net,
preprocess(moving, args.moving_modality, moving_segmentation),
preprocess(fixed, args.fixed_modality, fixed_segmentation),
finetune_steps=io_iterations)

itk.transformwrite([phi_AB], args.transform_out)

Expand All @@ -224,9 +286,11 @@ def main():
def warp_command():
import itk
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--fixed", required=True)
parser.add_argument("--moving", required=True)
parser = argparse.ArgumentParser(description="Warp an image with given transformation.")
parser.add_argument("--fixed", required=True, type=str,
help="The path of the fixed image.")
parser.add_argument("--moving", required=True, type=str,
help="The path of the moving image.")
parser.add_argument("--transform")
parser.add_argument("--warped_moving_out", required=True)
parser.add_argument('--nearest_neighbor', action='store_true')
Expand Down
Empty file added tests/__init__.py
Empty file.
205 changes: 205 additions & 0 deletions tests/test_itk_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import itk
import numpy as np
import unittest
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


import icon_registration.test_utils
import icon_registration.itk_wrapper
import icon_registration.pretrained_models

from unigradicon import preprocess, get_unigradicon


class TestItkInterface(unittest.TestCase):
def __init__(self, methodName: str = "runTest") -> None:
super().__init__(methodName)
icon_registration.test_utils.download_test_data()
self.test_data_dir = icon_registration.test_utils.TEST_DATA_DIR


def test_register_pair(self):
fixed_path = f"{self.test_data_dir}/brain_test_data/8_T1w_acpc_dc_restore_brain.nii.gz"
moving_path = f"{self.test_data_dir}/brain_test_data/2_T1w_acpc_dc_restore_brain.nii.gz"

# Run ITK interface
fixed = itk.imread(fixed_path)
moving = itk.imread(moving_path)

net = get_unigradicon()

phi_AB, phi_BA = icon_registration.itk_wrapper.register_pair(
net,
preprocess(moving, "mri"),
preprocess(fixed, "mri"),
finetune_steps=None)

phi_AB_vector = net.phi_AB_vectorfield

# Compute the reference
def preprocess_in_torch(img):
im_min, im_max = torch.min(img), np.quantile(img.numpy().flatten(), 0.99) #torch.quantile(img.view(-1), 0.99)
img = torch.clip(img, im_min, im_max)
img = (img-im_min) / (im_max-im_min)
return img

shape = net.identity_map.shape

fixed = torch.from_numpy(np.array(itk.imread(fixed_path), dtype=np.float32)).unsqueeze(0).unsqueeze(0)
fixed_in_net = preprocess_in_torch(fixed)
fixed_in_net = F.interpolate(fixed_in_net, shape[2:], mode='trilinear', align_corners=False)

moving = torch.Tensor(np.array(itk.imread(moving_path), dtype=np.float32)).unsqueeze(0).unsqueeze(0)
moving_in_net = preprocess_in_torch(moving)
moving_in_net = F.interpolate(moving_in_net, shape[2:], mode='trilinear', align_corners=False)

net = get_unigradicon()
with torch.no_grad():
net(moving_in_net.cuda(), fixed_in_net.cuda())

self.assertLess(
torch.mean(torch.abs(phi_AB_vector - net.phi_AB_vectorfield)), 1e-5
)


def test_preprocessing_mri(self):
img_path = f"{self.test_data_dir}/brain_test_data/8_T1w_acpc_dc_restore_brain.nii.gz"

# Run ITK interface
img = itk.imread(img_path)
img_preprocessed = preprocess(img, "mri")

# Compute the reference
def preprocess_in_torch(img):
im_min, im_max = torch.min(img), np.quantile(img.numpy().flatten(), 0.99) #torch.quantile(img.view(-1), 0.99)
img = torch.clip(img, im_min, im_max)
img = (img-im_min) / (im_max-im_min)
return img
reference = preprocess_in_torch(torch.Tensor(np.array(img, dtype=np.float32))).numpy()

self.assertLess(
np.mean(np.abs(reference - img_preprocessed)), 1e-5
)

def test_preprocessing_ct(self):
img_path = f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_img.nii.gz"

# Run ITK interface
img = itk.imread(img_path)
img_preprocessed = preprocess(img, "ct")

# Compute the reference
def preprocess_in_torch(img):
im_min, im_max = -1000, 1000
img = torch.clip(img, im_min, im_max)
img = (img-im_min) / (im_max-im_min)
return img
reference = preprocess_in_torch(torch.Tensor(np.array(img, dtype=np.float32))).numpy()

self.assertLess(
np.mean(np.abs(reference - img_preprocessed)), 1e-5
)

def test_itk_registration(self):
net = get_unigradicon()

image_exp = itk.imread(
str(
self.test_data_dir
/ "lung_test_data/copd1_highres_EXP_STD_COPD_img.nii.gz"
)
)
image_insp = itk.imread(
str(
self.test_data_dir
/ "lung_test_data/copd1_highres_INSP_STD_COPD_img.nii.gz"
)
)
image_exp_seg = itk.imread(
str(
self.test_data_dir
/ "lung_test_data/copd1_highres_EXP_STD_COPD_label.nii.gz"
)
)
image_insp_seg = itk.imread(
str(
self.test_data_dir
/ "lung_test_data/copd1_highres_INSP_STD_COPD_label.nii.gz"
)
)

image_insp_preprocessed = preprocess(image_insp, "ct", image_insp_seg)
image_exp_preprocessed = preprocess(image_exp, "ct", image_exp_seg)

phi_AB, phi_BA = icon_registration.itk_wrapper.register_pair(
net, image_insp_preprocessed, image_exp_preprocessed, finetune_steps=None
)

assert isinstance(phi_AB, itk.CompositeTransform)

insp_points = icon_registration.test_utils.read_copd_pointset(
str(
icon_registration.test_utils.TEST_DATA_DIR
/ "lung_test_data/copd1_300_iBH_xyz_r1.txt"
)
)
exp_points = icon_registration.test_utils.read_copd_pointset(
str(
icon_registration.test_utils.TEST_DATA_DIR
/ "lung_test_data/copd1_300_eBH_xyz_r1.txt"
)
)
dists = []
for i in range(len(insp_points)):
px, py = (
exp_points[i],
np.array(phi_BA.TransformPoint(tuple(insp_points[i]))),
)
dists.append(np.sqrt(np.sum((px - py) ** 2)))
self.assertLess(np.mean(dists), 1.7)

dists = []
for i in range(len(insp_points)):
px, py = (
insp_points[i],
np.array(phi_AB.TransformPoint(tuple(exp_points[i]))),
)
dists.append(np.sqrt(np.sum((px - py) ** 2)))
self.assertLess(np.mean(dists), 2.1)

def test_itk_warp(self):
fixed_path = f"{self.test_data_dir}/brain_test_data/8_T1w_acpc_dc_restore_brain.nii.gz"
moving_path = f"{self.test_data_dir}/brain_test_data/2_T1w_acpc_dc_restore_brain.nii.gz"

# Run ITK interface
fixed = itk.imread(fixed_path)
moving = itk.imread(moving_path)

net = get_unigradicon()

phi_AB, phi_BA = icon_registration.itk_wrapper.register_pair(
net,
preprocess(moving, "mri"),
preprocess(fixed, "mri"),
finetune_steps=None)

interpolator = itk.LinearInterpolateImageFunction.New(moving)
warped_moving_image = np.array(itk.resample_image_filter(
preprocess(moving, "mri"),
transform=phi_AB,
interpolator=interpolator,
use_reference_image=True,
reference_image=fixed
))

reference = F.interpolate(net.warped_image_A, size=warped_moving_image.shape, mode='trilinear', align_corners=False)[0,0].cpu().numpy()

from icon_registration.losses import NCC
diff = NCC()(torch.Tensor(warped_moving_image).unsqueeze(0).unsqueeze(0), torch.Tensor(reference).unsqueeze(0).unsqueeze(0))
self.assertLess(
diff, 5e-3
)

0 comments on commit 42b10fe

Please sign in to comment.