diff --git a/src/torch_fourier_rescale/__init__.py b/src/torch_fourier_rescale/__init__.py index bdf41ea..556d4aa 100644 --- a/src/torch_fourier_rescale/__init__.py +++ b/src/torch_fourier_rescale/__init__.py @@ -1,2 +1,2 @@ -from .fourier_rescale_2d import fourier_rescale_2d -from .fourier_rescale_3d import fourier_rescale_3d \ No newline at end of file +from .fourier_rescale_2d import fourier_rescale_2d, fourier_rescale_rfft_2d +from .fourier_rescale_3d import fourier_rescale_3d, fourier_rescale_rfft_3d diff --git a/src/torch_fourier_rescale/fourier_rescale_2d.py b/src/torch_fourier_rescale/fourier_rescale_2d.py index b8ad77a..1b01bb6 100644 --- a/src/torch_fourier_rescale/fourier_rescale_2d.py +++ b/src/torch_fourier_rescale/fourier_rescale_2d.py @@ -45,7 +45,7 @@ def fourier_rescale_2d( dft = torch.fft.fftshift(dft, dim=(-2,)) # Fourier pad/crop - dft, new_nyquist = _rescale_rfft_2d( + dft, new_nyquist = fourier_rescale_rfft_2d( dft=dft, image_shape=image.shape[-2:], source_spacing=source_spacing, @@ -65,7 +65,7 @@ def fourier_rescale_2d( return rescaled_image, tuple(new_spacing) -def _rescale_rfft_2d( +def fourier_rescale_rfft_2d( dft: torch.Tensor, image_shape: tuple[int, int], source_spacing: tuple[float, float], diff --git a/src/torch_fourier_rescale/fourier_rescale_3d.py b/src/torch_fourier_rescale/fourier_rescale_3d.py index 4dc4fc1..2a75dad 100644 --- a/src/torch_fourier_rescale/fourier_rescale_3d.py +++ b/src/torch_fourier_rescale/fourier_rescale_3d.py @@ -43,7 +43,7 @@ def fourier_rescale_3d( dft = torch.fft.fftshift(dft, dim=(-3, -2)) # Fourier pad/crop - dft, new_nyquist = _rescale_rfft_3d( + dft, new_nyquist = fourier_rescale_rfft_3d( dft=dft, image_shape=image.shape[-3:], source_spacing=source_spacing, @@ -63,7 +63,7 @@ def fourier_rescale_3d( return rescaled_image, tuple(new_spacing) -def _rescale_rfft_3d( +def fourier_rescale_rfft_3d( dft: torch.Tensor, image_shape: tuple[int, int, int], source_spacing: tuple[float, float, float],