Skip to content

Commit

Permalink
refactor: remove old common lines code
Browse files Browse the repository at this point in the history
  • Loading branch information
McHaillet committed Nov 16, 2024
1 parent 19064b2 commit f179111
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 43 deletions.
9 changes: 4 additions & 5 deletions src/tttsa/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from torch_fourier_slice import project_2d_to_1d

from .affine import affine_transform_2d
from .projection import common_lines_projection
from .transformations import R_2d, T_2d, stretch_matrix


Expand Down Expand Up @@ -70,10 +69,10 @@ def optimize_tilt_axis_angle(
coarse_aligned_masked = aligned_ts * coarse_alignment_mask

# generate a weighting for the common line ROI by projecting the mask
mask_weights = common_lines_projection(
einops.rearrange(coarse_alignment_mask, "h w -> 1 h w"),
0.0, # angle does not matter
)
mask_weights = project_2d_to_1d(
coarse_alignment_mask,
torch.eye(2).to(coarse_alignment_mask.device), # angle does not matter
).squeeze() # remove starting empty dimension
mask_weights /= mask_weights.max() # normalise to 0 and 1

# optimize tilt axis angle
Expand Down
3 changes: 1 addition & 2 deletions src/tttsa/projection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""Projection of images and volumes."""

from .project_real import common_lines_projection, tomogram_reprojection
from .project_real import tomogram_reprojection

__all__ = [
"common_lines_projection",
"tomogram_reprojection",
]
37 changes: 1 addition & 36 deletions src/tttsa/projection/project_real.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,50 +9,15 @@
from torch_grid_utils import coordinate_grid
from torch_image_lerp import insert_into_image_2d

from tttsa.affine import affine_transform_2d
from tttsa.transformations import (
R_2d,
T_2d,
projection_model_to_projection_matrix,
)
from tttsa.utils import dft_center, homogenise_coordinates
from tttsa.utils import homogenise_coordinates

# update shift
PMDL.SHIFT = [PMDL.SHIFT_Y, PMDL.SHIFT_X]


def common_lines_projection(
images: torch.Tensor,
tilt_axis_angles: torch.Tensor,
# this might as well takes shifts
) -> torch.Tensor:
"""Predict a projection from an intermediate reconstruction.
For now only assumes to project with a single matrix, but should also work for
sets of matrices.
"""
device = images.device
image_dimensions = images.shape[-2:]

# TODO pad image if not square

image_center = dft_center(image_dimensions, rfft=False, fftshifted=True)

# time for real space projection
s0 = T_2d(-image_center)
r0 = R_2d(tilt_axis_angles, yx=True)
s1 = T_2d(image_center)
# invert because the tilt axis angle is forward in the sample projection model
M = torch.linalg.inv(s1 @ r0 @ s0).to(device)

rotated = affine_transform_2d(
images,
M,
)
projections = rotated.mean(axis=-1).squeeze()
return projections


def tomogram_reprojection(
tomogram: torch.Tensor,
tilt_image_dimensions: Tuple[int, int],
Expand Down

0 comments on commit f179111

Please sign in to comment.