From ab35a93459ac9e2b1848a1db5c4374eb1e6bafe6 Mon Sep 17 00:00:00 2001 From: alisterburt Date: Sun, 16 Jun 2024 19:46:13 -0400 Subject: [PATCH] add basic tests and support for arbitrarily batched positions in 3D (#4) --- src/torch_subpixel_crop/subpixel_crop_3d.py | 8 +++-- tests/test_subpixel_crop_2d.py | 39 +++++++++++++++++++++ tests/test_subpixel_crop_3d.py | 39 +++++++++++++++++++++ tests/test_torch_subimage.py | 2 -- 4 files changed, 84 insertions(+), 4 deletions(-) create mode 100644 tests/test_subpixel_crop_2d.py create mode 100644 tests/test_subpixel_crop_3d.py delete mode 100644 tests/test_torch_subimage.py diff --git a/src/torch_subpixel_crop/subpixel_crop_3d.py b/src/torch_subpixel_crop/subpixel_crop_3d.py index 933edf1..9972d10 100644 --- a/src/torch_subpixel_crop/subpixel_crop_3d.py +++ b/src/torch_subpixel_crop/subpixel_crop_3d.py @@ -28,17 +28,18 @@ def subpixel_crop_3d( image: torch.Tensor `(d, h, w)` array containing the volume. positions: torch.Tensor - `(b, 3)` array of coordinates for patch centers. + `(..., 3)` array of coordinates for patch centers. sidelength: int Sidelength of cubic patches extracted from `image`. Returns ------- patches: torch.Tensor - `(b, sidelength, sidelength, sidelength)` array of cropped regions from `volume` + `(..., sidelength, sidelength, sidelength)` array of cropped regions from `volume` with their centers at `positions`. """ d, h, w = image.shape + positions, ps = einops.pack([positions], pattern='* zyx') b, _ = positions.shape # find integer positions and shifts to be applied @@ -68,4 +69,7 @@ def subpixel_crop_3d( # phase shift to center images patches = fourier_shift_image_3d(image=patches, shifts=shifts) + + # unpack + [patches] = einops.unpack(patches, pattern='* t h w', packed_shapes=ps) return patches diff --git a/tests/test_subpixel_crop_2d.py b/tests/test_subpixel_crop_2d.py new file mode 100644 index 0000000..41fc31a --- /dev/null +++ b/tests/test_subpixel_crop_2d.py @@ -0,0 +1,39 @@ +import torch + +from torch_subpixel_crop import subpixel_crop_2d + + +def test_subpixel_crop_single_2d(): + image = torch.zeros((10, 10)) + image[4:6, 4:6] = 1 + + cropped_image = subpixel_crop_2d( + image=image, + positions=torch.tensor([5, 5]).float(), + sidelength=4 + ) + assert cropped_image.shape == (4, 4) + + expected = torch.zeros((4, 4)) + expected[1:3, 1:3] = 1 + assert torch.allclose(cropped_image, expected) + + +def test_subpixel_crop_multi_2d(): + image = torch.zeros((10, 10)) + image[4:6, 4:6] = 1 + + cropped_image = subpixel_crop_2d( + image=image, + positions=torch.tensor([[4, 4], [5, 5]]).float(), + sidelength=4 + ) + assert cropped_image.shape == (2, 4, 4) + + expected_0 = torch.zeros((4, 4)) + expected_0[2:4, 2:4] = 1 + assert torch.allclose(cropped_image[0], expected_0) + + expected_1 = torch.zeros((4, 4)) + expected_1[1:3, 1:3] = 1 + assert torch.allclose(cropped_image[1], expected_1) diff --git a/tests/test_subpixel_crop_3d.py b/tests/test_subpixel_crop_3d.py new file mode 100644 index 0000000..af97668 --- /dev/null +++ b/tests/test_subpixel_crop_3d.py @@ -0,0 +1,39 @@ +import torch + +from torch_subpixel_crop import subpixel_crop_3d + + +def test_subpixel_crop_single_3d(): + image = torch.zeros((10, 10, 10)) + image[4:6, 4:6, 4:6] = 1 + + cropped_image = subpixel_crop_3d( + image=image, + positions=torch.tensor([5, 5, 5]).float(), + sidelength=4 + ) + assert cropped_image.shape == (4, 4, 4) + + expected = torch.zeros((4, 4, 4)) + expected[1:3, 1:3, 1:3] = 1 + assert torch.allclose(cropped_image, expected) + + +def test_subpixel_crop_multi_3d(): + image = torch.zeros((10, 10, 10)) + image[4:6, 4:6, 4:6] = 1 + + cropped_image = subpixel_crop_3d( + image=image, + positions=torch.tensor([[4, 4, 4], [5, 5, 5]]).float(), + sidelength=4 + ) + assert cropped_image.shape == (2, 4, 4, 4) + + expected_0 = torch.zeros((4, 4, 4)) + expected_0[2:4, 2:4, 2:4] = 1 + assert torch.allclose(cropped_image[0], expected_0) + + expected_1 = torch.zeros((4, 4, 4)) + expected_1[1:3, 1:3, 1:3] = 1 + assert torch.allclose(cropped_image[1], expected_1) diff --git a/tests/test_torch_subimage.py b/tests/test_torch_subimage.py deleted file mode 100644 index 363b3e2..0000000 --- a/tests/test_torch_subimage.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_something(): - pass