Skip to content

Commit

Permalink
add basic tests and support for arbitrarily batched positions in 3D (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
alisterburt authored Jun 16, 2024
1 parent 1253df6 commit ab35a93
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 4 deletions.
8 changes: 6 additions & 2 deletions src/torch_subpixel_crop/subpixel_crop_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,18 @@ def subpixel_crop_3d(
image: torch.Tensor
`(d, h, w)` array containing the volume.
positions: torch.Tensor
`(b, 3)` array of coordinates for patch centers.
`(..., 3)` array of coordinates for patch centers.
sidelength: int
Sidelength of cubic patches extracted from `image`.
Returns
-------
patches: torch.Tensor
`(b, sidelength, sidelength, sidelength)` array of cropped regions from `volume`
`(..., sidelength, sidelength, sidelength)` array of cropped regions from `volume`
with their centers at `positions`.
"""
d, h, w = image.shape
positions, ps = einops.pack([positions], pattern='* zyx')
b, _ = positions.shape

# find integer positions and shifts to be applied
Expand Down Expand Up @@ -68,4 +69,7 @@ def subpixel_crop_3d(

# phase shift to center images
patches = fourier_shift_image_3d(image=patches, shifts=shifts)

# unpack
[patches] = einops.unpack(patches, pattern='* t h w', packed_shapes=ps)
return patches
39 changes: 39 additions & 0 deletions tests/test_subpixel_crop_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch

from torch_subpixel_crop import subpixel_crop_2d


def test_subpixel_crop_single_2d():
image = torch.zeros((10, 10))
image[4:6, 4:6] = 1

cropped_image = subpixel_crop_2d(
image=image,
positions=torch.tensor([5, 5]).float(),
sidelength=4
)
assert cropped_image.shape == (4, 4)

expected = torch.zeros((4, 4))
expected[1:3, 1:3] = 1
assert torch.allclose(cropped_image, expected)


def test_subpixel_crop_multi_2d():
image = torch.zeros((10, 10))
image[4:6, 4:6] = 1

cropped_image = subpixel_crop_2d(
image=image,
positions=torch.tensor([[4, 4], [5, 5]]).float(),
sidelength=4
)
assert cropped_image.shape == (2, 4, 4)

expected_0 = torch.zeros((4, 4))
expected_0[2:4, 2:4] = 1
assert torch.allclose(cropped_image[0], expected_0)

expected_1 = torch.zeros((4, 4))
expected_1[1:3, 1:3] = 1
assert torch.allclose(cropped_image[1], expected_1)
39 changes: 39 additions & 0 deletions tests/test_subpixel_crop_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch

from torch_subpixel_crop import subpixel_crop_3d


def test_subpixel_crop_single_3d():
image = torch.zeros((10, 10, 10))
image[4:6, 4:6, 4:6] = 1

cropped_image = subpixel_crop_3d(
image=image,
positions=torch.tensor([5, 5, 5]).float(),
sidelength=4
)
assert cropped_image.shape == (4, 4, 4)

expected = torch.zeros((4, 4, 4))
expected[1:3, 1:3, 1:3] = 1
assert torch.allclose(cropped_image, expected)


def test_subpixel_crop_multi_3d():
image = torch.zeros((10, 10, 10))
image[4:6, 4:6, 4:6] = 1

cropped_image = subpixel_crop_3d(
image=image,
positions=torch.tensor([[4, 4, 4], [5, 5, 5]]).float(),
sidelength=4
)
assert cropped_image.shape == (2, 4, 4, 4)

expected_0 = torch.zeros((4, 4, 4))
expected_0[2:4, 2:4, 2:4] = 1
assert torch.allclose(cropped_image[0], expected_0)

expected_1 = torch.zeros((4, 4, 4))
expected_1[1:3, 1:3, 1:3] = 1
assert torch.allclose(cropped_image[1], expected_1)
2 changes: 0 additions & 2 deletions tests/test_torch_subimage.py

This file was deleted.

0 comments on commit ab35a93

Please sign in to comment.