From f179111f50f7338b95a20ca4ef746724a05aac03 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Sat, 16 Nov 2024 15:49:53 +0100 Subject: [PATCH] refactor: remove old common lines code --- src/tttsa/optimizers.py | 9 +++---- src/tttsa/projection/__init__.py | 3 +-- src/tttsa/projection/project_real.py | 37 +--------------------------- 3 files changed, 6 insertions(+), 43 deletions(-) diff --git a/src/tttsa/optimizers.py b/src/tttsa/optimizers.py index a78dbeb..6045c8d 100644 --- a/src/tttsa/optimizers.py +++ b/src/tttsa/optimizers.py @@ -6,7 +6,6 @@ from torch_fourier_slice import project_2d_to_1d from .affine import affine_transform_2d -from .projection import common_lines_projection from .transformations import R_2d, T_2d, stretch_matrix @@ -70,10 +69,10 @@ def optimize_tilt_axis_angle( coarse_aligned_masked = aligned_ts * coarse_alignment_mask # generate a weighting for the common line ROI by projecting the mask - mask_weights = common_lines_projection( - einops.rearrange(coarse_alignment_mask, "h w -> 1 h w"), - 0.0, # angle does not matter - ) + mask_weights = project_2d_to_1d( + coarse_alignment_mask, + torch.eye(2).to(coarse_alignment_mask.device), # angle does not matter + ).squeeze() # remove starting empty dimension mask_weights /= mask_weights.max() # normalise to 0 and 1 # optimize tilt axis angle diff --git a/src/tttsa/projection/__init__.py b/src/tttsa/projection/__init__.py index d054365..e7aef33 100644 --- a/src/tttsa/projection/__init__.py +++ b/src/tttsa/projection/__init__.py @@ -1,8 +1,7 @@ """Projection of images and volumes.""" -from .project_real import common_lines_projection, tomogram_reprojection +from .project_real import tomogram_reprojection __all__ = [ - "common_lines_projection", "tomogram_reprojection", ] diff --git a/src/tttsa/projection/project_real.py b/src/tttsa/projection/project_real.py index 69f3948..349336b 100644 --- a/src/tttsa/projection/project_real.py +++ b/src/tttsa/projection/project_real.py @@ -9,50 +9,15 @@ from torch_grid_utils import coordinate_grid from torch_image_lerp import insert_into_image_2d -from tttsa.affine import affine_transform_2d from tttsa.transformations import ( - R_2d, - T_2d, projection_model_to_projection_matrix, ) -from tttsa.utils import dft_center, homogenise_coordinates +from tttsa.utils import homogenise_coordinates # update shift PMDL.SHIFT = [PMDL.SHIFT_Y, PMDL.SHIFT_X] -def common_lines_projection( - images: torch.Tensor, - tilt_axis_angles: torch.Tensor, - # this might as well takes shifts -) -> torch.Tensor: - """Predict a projection from an intermediate reconstruction. - - For now only assumes to project with a single matrix, but should also work for - sets of matrices. - """ - device = images.device - image_dimensions = images.shape[-2:] - - # TODO pad image if not square - - image_center = dft_center(image_dimensions, rfft=False, fftshifted=True) - - # time for real space projection - s0 = T_2d(-image_center) - r0 = R_2d(tilt_axis_angles, yx=True) - s1 = T_2d(image_center) - # invert because the tilt axis angle is forward in the sample projection model - M = torch.linalg.inv(s1 @ r0 @ s0).to(device) - - rotated = affine_transform_2d( - images, - M, - ) - projections = rotated.mean(axis=-1).squeeze() - return projections - - def tomogram_reprojection( tomogram: torch.Tensor, tilt_image_dimensions: Tuple[int, int],