From ecaa1bc0a92d0f17cc80c51734ebe46e045ed860 Mon Sep 17 00:00:00 2001 From: Marten Chaillet <58044494+McHaillet@users.noreply.github.com> Date: Thu, 14 Nov 2024 18:42:03 +0100 Subject: [PATCH] use cryotypes for projection model backend (#18) * feat: initialize cryotypes projection model in usage example * feat: rely everywhere on cryotypes ProjectionModel * refactor: add utility functions for matrix generation to facilitate projection_model to matrix conversion * refactor: remoeve image stretch function * fix: set cryotypes version to 0.2 --- examples/usage.py | 37 ++++--- pyproject.toml | 1 + src/tttsa/affine/__init__.py | 3 +- src/tttsa/affine/affine_transform.py | 26 +---- .../filtered_back_projection.py | 54 ++++----- src/tttsa/coarse_align.py | 57 +++++----- src/tttsa/optimizers.py | 44 ++++---- src/tttsa/projection/project_real.py | 28 ++--- src/tttsa/projection_matching.py | 41 +++---- src/tttsa/transformations.py | 95 ++++++++++++++-- src/tttsa/tttsa.py | 103 ++++++++++-------- tests/affine/test_affine_transform.py | 8 +- 12 files changed, 282 insertions(+), 215 deletions(-) diff --git a/examples/usage.py b/examples/usage.py index 59b5f76..eabe2a0 100644 --- a/examples/usage.py +++ b/examples/usage.py @@ -7,6 +7,8 @@ import numpy as np import pooch import torch +from cryotypes.projectionmodel import ProjectionModel +from cryotypes.projectionmodel import ProjectionModelDataLabels as PMDL from torch_fourier_rescale import fourier_rescale_2d from torch_subpixel_crop import subpixel_crop_2d @@ -26,10 +28,10 @@ IMAGE_FILE = Path(GOODBOY.fetch("tomo200528_107.st", progressbar=True)) with open(Path(GOODBOY.fetch("tomo200528_107.rawtlt"))) as f: - STAGE_TILT_ANGLE_PRIORS = torch.tensor([float(x) for x in f.readlines()]) + STAGE_TILT_ANGLE_PRIORS = [float(x) for x in f.readlines()] IMAGE_PIXEL_SIZE = 1.724 # this angle is assumed to be a clockwise forward rotation after projecting the sample -TILT_AXIS_ANGLE_PRIOR = torch.tensor(-88.7) +TILT_AXIS_ANGLE_PRIOR = -88.7 ALIGNMENT_PIXEL_SIZE = IMAGE_PIXEL_SIZE * 8 ALIGN_Z = int(1600 / ALIGNMENT_PIXEL_SIZE) # number is in A RECON_Z = int(2400 / ALIGNMENT_PIXEL_SIZE) @@ -41,6 +43,19 @@ # Set the device for running DEVICE = "cuda:0" +# Initialize the projection-model prior +projection_model_prior = ProjectionModel( + { + PMDL.ROTATION_Z: TILT_AXIS_ANGLE_PRIOR, + PMDL.ROTATION_Y: STAGE_TILT_ANGLE_PRIORS, + PMDL.ROTATION_X: 0.0, + PMDL.SHIFT_X: 0.0, + PMDL.SHIFT_Y: 0.0, + PMDL.EXPERIMENT_ID: IMAGE_FILE.stem, + PMDL.PIXEL_SPACING: ALIGNMENT_PIXEL_SIZE, + PMDL.SOURCE: IMAGE_FILE.name, + } +) tilt_series = torch.as_tensor(mrcfile.read(IMAGE_FILE)) @@ -69,25 +84,21 @@ size = min(h, w) # Move all the input to the device -tilt_angles, tilt_axis_angles, shifts = tilt_series_alignment( +projection_model_optimized = tilt_series_alignment( tilt_series.to(DEVICE), - STAGE_TILT_ANGLE_PRIORS, - TILT_AXIS_ANGLE_PRIOR, + projection_model_prior, ALIGN_Z, find_tilt_angle_offset=False, ) -final, aligned_ts = filtered_back_projection_3d( +final = filtered_back_projection_3d( tilt_series, (RECON_Z, size, size), - tilt_angles, - tilt_axis_angles, - shifts, + projection_model_optimized, weighting=WEIGHTING, object_diameter=OBJECT_DIAMETER, ) final = final.to("cpu") -aligned_ts = aligned_ts.to("cpu") OUTPUT_DIR.mkdir(exist_ok=True) mrcfile.write( @@ -96,9 +107,3 @@ voxel_size=ALIGNMENT_PIXEL_SIZE, overwrite=True, ) -mrcfile.write( - OUTPUT_DIR.joinpath(IMAGE_FILE.with_suffix(".ali").name), - aligned_ts.detach().numpy().astype(np.float32), - voxel_size=ALIGNMENT_PIXEL_SIZE, - overwrite=True, -) diff --git a/pyproject.toml b/pyproject.toml index acf962c..39cb1c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "torch-cubic-spline-grids", "torch-fourier-shift", "torch-image-lerp", + "cryotypes == 0.2", "einops", "numpy", "scipy", diff --git a/src/tttsa/affine/__init__.py b/src/tttsa/affine/__init__.py index 647440e..facd65e 100644 --- a/src/tttsa/affine/__init__.py +++ b/src/tttsa/affine/__init__.py @@ -1,9 +1,8 @@ """2D and 3D affine transform functionality.""" -from .affine_transform import affine_transform_2d, affine_transform_3d, stretch_image +from .affine_transform import affine_transform_2d, affine_transform_3d __all__ = [ "affine_transform_2d", "affine_transform_3d", - "stretch_image", ] diff --git a/src/tttsa/affine/affine_transform.py b/src/tttsa/affine/affine_transform.py index 0290dc6..89155b5 100644 --- a/src/tttsa/affine/affine_transform.py +++ b/src/tttsa/affine/affine_transform.py @@ -7,31 +7,7 @@ import torch.nn.functional as F from torch_grid_utils import coordinate_grid -from tttsa.transformations import R_2d, T_2d -from tttsa.utils import array_to_grid_sample, dft_center, homogenise_coordinates - - -def stretch_image( - image: torch.Tensor, - stretch: torch.Tensor | float, - tilt_axis_angle: torch.Tensor | float, -) -> torch.Tensor: - """Utility function for stretching an image on the tilt axis.""" - image_center = dft_center(image.shape, rfft=False, fftshifted=True) - # construct matrix - s0 = T_2d(-image_center) - r_forward = R_2d(tilt_axis_angle, yx=True) - r_backward = torch.linalg.inv(r_forward) - m_stretch = torch.eye(3) - m_stretch[1, 1] = stretch # this is a shear matrix - s1 = T_2d(image_center) - m_affine = s1 @ r_forward @ m_stretch @ r_backward @ s0 - # transform image - stretched = affine_transform_2d( - image, - m_affine, - ) - return stretched +from tttsa.utils import array_to_grid_sample, homogenise_coordinates def affine_transform_2d( diff --git a/src/tttsa/back_projection/filtered_back_projection.py b/src/tttsa/back_projection/filtered_back_projection.py index fa23dc2..992d23e 100644 --- a/src/tttsa/back_projection/filtered_back_projection.py +++ b/src/tttsa/back_projection/filtered_back_projection.py @@ -5,19 +5,25 @@ import einops import torch import torch.nn.functional as F +from cryotypes.projectionmodel import ProjectionModel +from cryotypes.projectionmodel import ProjectionModelDataLabels as PMDL from torch_grid_utils import coordinate_grid from tttsa.affine import affine_transform_2d -from tttsa.transformations import R_2d, Ry, T, T_2d -from tttsa.utils import array_to_grid_sample, dft_center, homogenise_coordinates +from tttsa.transformations import ( + projection_model_to_backproject_matrix, + projection_model_to_tsa_matrix, +) +from tttsa.utils import array_to_grid_sample, homogenise_coordinates + +# update shift +PMDL.SHIFT = [PMDL.SHIFT_Y, PMDL.SHIFT_X] def filtered_back_projection_3d( tilt_series: torch.Tensor, tomogram_dimensions: Tuple[int, int, int], - tilt_angles: torch.Tensor, - tilt_axis_angles: torch.Tensor, - shifts: torch.Tensor, + projection_model: ProjectionModel, weighting: str = "exact", object_diameter: float | None = None, ) -> torch.Tensor: @@ -43,19 +49,12 @@ def filtered_back_projection_3d( n_tilts, h, w = tilt_series.shape # for simplicity assume square images tilt_image_dimensions = (h, w) transformed_image_dimensions = tomogram_dimensions[-2:] - tomogram_center = dft_center(tomogram_dimensions, rfft=False, fftshifted=True) - tilt_image_center = dft_center(tilt_image_dimensions, rfft=False, fftshifted=True) - transformed_image_center = dft_center( - transformed_image_dimensions, rfft=False, fftshifted=True - ) _, filter_size = transformed_image_dimensions # generate the 2d alignment affine matrix - s0 = T_2d(-tilt_image_center) - r0 = R_2d(tilt_axis_angles, yx=True) - s1 = T_2d(-shifts) - s2 = T_2d(transformed_image_center) - M = torch.linalg.inv(s2 @ s1 @ r0 @ s0).to(device) + M = projection_model_to_tsa_matrix( + projection_model, tilt_image_dimensions, transformed_image_dimensions + ).to(device) aligned_ts = affine_transform_2d( tilt_series, @@ -69,7 +68,7 @@ def filtered_back_projection_3d( raise ValueError( "Calculation of exact weighting requires an object " "diameter." ) - if len(tilt_angles) == 1: + if n_tilts == 1: # set explicitly as tensor to ensure correct typing filters = torch.tensor(1.0, device=device) else: # slice_width could be provided as a function argument it can be @@ -83,6 +82,7 @@ def filtered_back_projection_3d( / filter_size, "q -> 1 1 q", ) + tilt_angles = torch.as_tensor(projection_model[PMDL.ROTATION_Y]) sampling = torch.sin( torch.deg2rad( torch.abs(einops.rearrange(tilt_angles, "n -> n 1") - tilt_angles) @@ -124,15 +124,17 @@ def filtered_back_projection_3d( if len(weighted.shape) == 2: # rfftn gets rid of batch dimension: add it back weighted = einops.rearrange(weighted, "h w -> 1 h w") - # create recon from weighted-aligned ts - s0 = T(-tomogram_center) - r0 = Ry(tilt_angles, zyx=True) - s1 = T(tomogram_center) - # This would actually be a double linalg.inv. First for the inverse of the - # forward projection alignment model. The second for the affine transform. - # It could be more logical to use affine_transform_3d, but it requires - # recalculation of the grid for every iteration. - M = einops.rearrange(s1 @ r0 @ s0, "... i j -> ... 1 1 i j").to(device) + # We need to lingalg.inv the matrix as the affine transform is done inside + # this function. It could be more logical to use affine_transform_3d (and do + # inversion inside) but it requires recalculation of the grid for every iteration. + M = einops.rearrange( + torch.linalg.inv( + projection_model_to_backproject_matrix( + projection_model, tomogram_dimensions + ) + ), + "... i j -> ... 1 1 i j", + ).to(device) reconstruction = torch.zeros( tomogram_dimensions, dtype=torch.float32, device=device @@ -156,4 +158,4 @@ def filtered_back_projection_3d( mode="bilinear", ) ) - return reconstruction, aligned_ts + return reconstruction diff --git a/src/tttsa/coarse_align.py b/src/tttsa/coarse_align.py index 1f5c855..0a31067 100644 --- a/src/tttsa/coarse_align.py +++ b/src/tttsa/coarse_align.py @@ -1,9 +1,11 @@ """Coarse tilt-series alignment functions, also with stretching.""" +import einops import torch -from .affine import stretch_image +from .affine import affine_transform_2d from .alignment import find_image_shift +from .transformations import stretch_matrix def coarse_align( @@ -12,20 +14,25 @@ def coarse_align( mask: torch.Tensor, ) -> torch.Tensor: """Find coarse shifts of images without stretching along tilt axis.""" - shifts = torch.zeros((len(tilt_series), 2), dtype=torch.float32) + n_tilts = len(tilt_series) + shifts = torch.zeros((n_tilts, 2), dtype=torch.float32) + ts_masked = tilt_series * mask + ts_masked -= einops.reduce(ts_masked, "tilt h w -> tilt 1 1", reduction="mean") + ts_masked /= torch.std(ts_masked, dim=(-2, -1), keepdim=True) + # find coarse alignment for negative tilts current_shift = torch.zeros(2) for i in range(reference_tilt_id, 0, -1): - shift = find_image_shift(tilt_series[i] * mask, tilt_series[i - 1] * mask) + shift = find_image_shift(ts_masked[i], ts_masked[i - 1]) current_shift += shift shifts[i - 1] = current_shift # find coarse alignment positive tilts current_shift = torch.zeros(2) - for i in range(reference_tilt_id, tilt_series.shape[0] - 1, 1): + for i in range(reference_tilt_id, n_tilts - 1, 1): shift = find_image_shift( - tilt_series[i] * mask, - tilt_series[i + 1] * mask, + ts_masked[i], + ts_masked[i + 1], ) current_shift += shift shifts[i + 1] = current_shift @@ -40,21 +47,20 @@ def stretch_align( tilt_axis_angles: torch.Tensor, ) -> torch.Tensor: """Find coarse shifts of images while stretching each pair along the tilt axis.""" - shifts = torch.zeros((len(tilt_series), 2), dtype=torch.float32) + n_tilts, h, w = tilt_series.shape + tilt_image_dimensions = (h, w) + shifts = torch.zeros((n_tilts, 2), dtype=torch.float32) + cos_ta = torch.cos(torch.deg2rad(tilt_angles)) + # find coarse alignment for negative tilts current_shift = torch.zeros(2) for i in range(reference_tilt_id, 0, -1): - scale_factor = torch.cos(torch.deg2rad(tilt_angles[i : i + 1])) / torch.cos( - torch.deg2rad(tilt_angles[i - 1 : i]) - ) - stretched = ( - stretch_image( - tilt_series[i - 1], - scale_factor, - tilt_axis_angles[i - 1], - ) - * mask + M = stretch_matrix( + tilt_image_dimensions, + tilt_axis_angles[i - 1], + scale_factor=cos_ta[i : i + 1] / cos_ta[i - 1 : i], ) + stretched = affine_transform_2d(tilt_series[i - 1], M) * mask stretched = (stretched - stretched.mean()) / stretched.std() raw = tilt_series[i] * mask raw = (raw - raw.mean()) / raw.std() @@ -63,18 +69,13 @@ def stretch_align( shifts[i - 1] = current_shift # find coarse alignment positive tilts current_shift = torch.zeros(2) - for i in range(reference_tilt_id, tilt_series.shape[0] - 1, 1): - scale_factor = torch.cos(torch.deg2rad(tilt_angles[i : i + 1])) / torch.cos( - torch.deg2rad(tilt_angles[i + 1 : i + 2]) - ) - stretched = ( - stretch_image( - tilt_series[i + 1], - scale_factor, - tilt_axis_angles[i + 1], - ) - * mask + for i in range(reference_tilt_id, n_tilts - 1, 1): + M = stretch_matrix( + tilt_image_dimensions, + tilt_axis_angles[i + 1], + scale_factor=cos_ta[i : i + 1] / cos_ta[i + 1 : i + 2], ) + stretched = affine_transform_2d(tilt_series[i + 1], M) * mask stretched = (stretched - stretched.mean()) / stretched.std() raw = tilt_series[i] * mask raw = (raw - raw.mean()) / raw.std() diff --git a/src/tttsa/optimizers.py b/src/tttsa/optimizers.py index a386957..f76d96f 100644 --- a/src/tttsa/optimizers.py +++ b/src/tttsa/optimizers.py @@ -4,9 +4,9 @@ import torch from torch_cubic_spline_grids import CubicBSplineGrid1d -from .affine import affine_transform_2d, stretch_image +from .affine import affine_transform_2d from .projection import common_lines_projection -from .transformations import T_2d +from .transformations import T_2d, stretch_matrix def stretch_loss( @@ -19,20 +19,19 @@ def stretch_loss( ) -> torch.Tensor: """Find coarse shifts of images while stretching each pair along the tilt axis.""" device = tilt_series.device + n_tilts, h, w = tilt_series.shape + tilt_image_dimensions = (h, w) + cos_ta = torch.cos(torch.deg2rad(tilt_angles)) + sq_diff = torch.tensor(0.0, device=device) for i in range(reference_tilt_id, 0, -1): - scale_factor = torch.cos(torch.deg2rad(tilt_angles[i : i + 1])) / torch.cos( - torch.deg2rad(tilt_angles[i - 1 : i]) - ) - stretched = stretch_image( # stretch image i - 1 - tilt_series[i - 1], - scale_factor, + # multiply stretch matrix by shift for full alignment + M = T_2d(shifts[i - 1] - shifts[i]) @ stretch_matrix( + tilt_image_dimensions, tilt_axis_angles[i - 1], - ) - stretched = affine_transform_2d( # shift to the same position as i - stretched, - T_2d(shifts[i - 1] - shifts[i]), - ) + scale_factor=cos_ta[i : i + 1] / cos_ta[i - 1 : i], + ) # slicing cos_ta ensure gradient calculation + stretched = affine_transform_2d(tilt_series[i - 1], M) non_empty = (stretched != 0) * 1.0 correlation_mask = non_empty * mask stretched = stretched * correlation_mask @@ -42,19 +41,14 @@ def stretch_loss( sq_diff = sq_diff + ((ref - stretched) ** 2).sum() / stretched.numel() # find coarse alignment positive tilts - for i in range(reference_tilt_id, tilt_series.shape[0] - 1, 1): - scale_factor = torch.cos(torch.deg2rad(tilt_angles[i : i + 1])) / torch.cos( - torch.deg2rad(tilt_angles[i + 1 : i + 2]) - ) - stretched = stretch_image( # stretch image i + 1 - tilt_series[i + 1], - scale_factor, + for i in range(reference_tilt_id, n_tilts - 1, 1): + # multiply stretch matrix by shift for full alignment + M = T_2d(shifts[i + 1] - shifts[i]) @ stretch_matrix( + tilt_image_dimensions, tilt_axis_angles[i + 1], - ) - stretched = affine_transform_2d( # shift to positions of i - stretched, - T_2d(shifts[i + 1] - shifts[i]), - ) + scale_factor=cos_ta[i : i + 1] / cos_ta[i + 1 : i + 2], + ) # slicing cos_ta ensure gradient calculation + stretched = affine_transform_2d(tilt_series[i + 1], M) non_empty = (stretched != 0) * 1.0 correlation_mask = non_empty * mask stretched = stretched * correlation_mask diff --git a/src/tttsa/projection/project_real.py b/src/tttsa/projection/project_real.py index 9dfb427..b1608c0 100644 --- a/src/tttsa/projection/project_real.py +++ b/src/tttsa/projection/project_real.py @@ -4,14 +4,22 @@ import einops import torch -import torch.nn.functional as F +from cryotypes.projectionmodel import ProjectionModel +from cryotypes.projectionmodel import ProjectionModelDataLabels as PMDL from torch_grid_utils import coordinate_grid from torch_image_lerp import insert_into_image_2d from tttsa.affine import affine_transform_2d -from tttsa.transformations import R_2d, Ry, Rz, T, T_2d +from tttsa.transformations import ( + R_2d, + T_2d, + projection_model_to_projection_matrix, +) from tttsa.utils import dft_center, homogenise_coordinates +# update shift +PMDL.SHIFT = [PMDL.SHIFT_Y, PMDL.SHIFT_X] + def common_lines_projection( images: torch.Tensor, @@ -48,9 +56,7 @@ def common_lines_projection( def tomogram_reprojection( tomogram: torch.Tensor, tilt_image_dimensions: Tuple[int, int], - tilt_angles: torch.Tensor, - tilt_axis_angles: torch.Tensor, - shifts: torch.Tensor, + projection_model: ProjectionModel, ) -> Tuple[torch.Tensor, torch.Tensor]: """Predict a projection from an intermediate reconstruction. @@ -59,17 +65,11 @@ def tomogram_reprojection( """ device = tomogram.device tomogram_dimensions = tomogram.shape - tomogram_center = dft_center(tomogram_dimensions, rfft=False, fftshifted=True) - transform_shape = (tomogram_dimensions[0], *tilt_image_dimensions) - transform_center = dft_center(transform_shape, rfft=False, fftshifted=True) # time for real space projection - s0 = T(-transform_center) - r0 = Ry(tilt_angles, zyx=True) - r1 = Rz(tilt_axis_angles, zyx=True) - s1 = T(F.pad(-shifts, pad=(1, 0), value=0)) - s2 = T(tomogram_center) - M = s2 @ s1 @ r1 @ r0 @ s0 + M = projection_model_to_projection_matrix( + projection_model, tilt_image_dimensions, tomogram_dimensions + ) Mproj = M[:, 1:3, :] Mproj = einops.rearrange(Mproj, "... i j -> ... 1 1 i j").to(device) diff --git a/src/tttsa/projection_matching.py b/src/tttsa/projection_matching.py index 87d4c43..66b4293 100644 --- a/src/tttsa/projection_matching.py +++ b/src/tttsa/projection_matching.py @@ -4,35 +4,42 @@ import einops import torch +from cryotypes.projectionmodel import ProjectionModel +from cryotypes.projectionmodel import ProjectionModelDataLabels as PMDL from rich.progress import track from .alignment import find_image_shift from .back_projection import filtered_back_projection_3d from .projection import tomogram_reprojection +# update shift +PMDL.SHIFT = [PMDL.SHIFT_Y, PMDL.SHIFT_X] + def projection_matching( tilt_series: torch.Tensor, - tomogram_dimensions: Tuple[int, int, int], + projection_model_in: ProjectionModel, reference_tilt_id: int, - tilt_angles: torch.Tensor, - tilt_axis_angles: torch.Tensor, - current_shifts: torch.Tensor, alignment_mask: torch.Tensor, + tomogram_dimensions: Tuple[int, int, int], reconstruction_weighting: str = "hamming", exact_weighting_object_diameter: float | None = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[ProjectionModel, torch.Tensor]: """Run projection matching.""" device = tilt_series.device n_tilts, size, _ = tilt_series.shape aligned_set = [reference_tilt_id] - shifts = current_shifts.detach().clone() + # copy the model to update with new shifts + projection_model_out = projection_model_in.copy(deep=True) + tilt_angles = torch.tensor( # to tensor as we need it to calculate weights + projection_model_out[PMDL.ROTATION_Y].to_numpy(), dtype=tilt_series.dtype + ) # generate indices by alternating postive/negative tilts - max_offset = max(reference_tilt_id, len(tilt_angles) - reference_tilt_id - 1) + max_offset = max(reference_tilt_id, n_tilts - reference_tilt_id - 1) index_sequence = [] for i in range(1, max_offset + 1): # skip reference - if reference_tilt_id + i < len(tilt_angles): + if reference_tilt_id + i < n_tilts: index_sequence.append(reference_tilt_id + i) if i > 0 and reference_tilt_id - i >= 0: index_sequence.append(reference_tilt_id - i) @@ -46,22 +53,18 @@ def projection_matching( weights = einops.rearrange( torch.cos(torch.deg2rad(torch.abs(tilt_angles - tilt_angle))), "n -> n 1 1", - ) - intermediate_recon, _ = filtered_back_projection_3d( - tilt_series[aligned_set,] * weights[aligned_set,].to(device), + ).to(device) + intermediate_recon = filtered_back_projection_3d( + tilt_series[aligned_set,] * weights[aligned_set,], tomogram_dimensions, - tilt_angles[aligned_set,], - tilt_axis_angles[aligned_set,], - shifts[aligned_set,], + projection_model_out.iloc[aligned_set,], weighting=reconstruction_weighting, object_diameter=exact_weighting_object_diameter, ) projection, projection_weights = tomogram_reprojection( intermediate_recon, (size, size), - tilt_angles[[i],], - tilt_axis_angles[[i],], - shifts[[i],], + projection_model_out.iloc[[i],], ) # ensure correlation in relevant area @@ -75,10 +78,10 @@ def projection_matching( raw, projection, ) - shifts[i] -= shift + projection_model_out.loc[i, PMDL.SHIFT] += shift.numpy() aligned_set.append(i) # for debugging: projections[i] = projection.detach().cpu() - return shifts, projections + return projection_model_out, projections diff --git a/src/tttsa/transformations.py b/src/tttsa/transformations.py index b1afb79..4522b43 100644 --- a/src/tttsa/transformations.py +++ b/src/tttsa/transformations.py @@ -1,15 +1,23 @@ """4x4 matrices for rotations and translations. Functions in this module generate matrices which left-multiply column vectors containing -`xyzw` or `zyxw` homogenous coordinates. +`xyzw` or `zyxw` homogeneous coordinates. """ +from typing import Tuple + import einops import torch +import torch.nn.functional as F +from cryotypes.projectionmodel import ProjectionModel +from cryotypes.projectionmodel import ProjectionModelDataLabels as PMDL + +# update shift +PMDL.SHIFT = [PMDL.SHIFT_Y, PMDL.SHIFT_X] def Rx(angles_degrees: torch.Tensor, zyx: bool = False) -> torch.Tensor: - """4x4 matrices for a rotation of homogenous coordinates around the X-axis. + """4x4 matrices for a rotation of homogeneous coordinates around the X-axis. Parameters ---------- @@ -17,7 +25,7 @@ def Rx(angles_degrees: torch.Tensor, zyx: bool = False) -> torch.Tensor: `(..., )` array of angles zyx: bool Whether output should be compatible with `zyxw` (`True`) or `xyzw` - (`False`) homogenous coordinates. + (`False`) homogeneous coordinates. Returns ------- @@ -41,7 +49,7 @@ def Rx(angles_degrees: torch.Tensor, zyx: bool = False) -> torch.Tensor: def Ry(angles_degrees: torch.Tensor, zyx: bool = False) -> torch.Tensor: - """4x4 matrices for a rotation of homogenous coordinates around the Y-axis. + """4x4 matrices for a rotation of homogeneous coordinates around the Y-axis. Parameters ---------- @@ -49,7 +57,7 @@ def Ry(angles_degrees: torch.Tensor, zyx: bool = False) -> torch.Tensor: `(..., )` array of angles zyx: bool Whether output should be compatible with `zyxw` (`True`) or `xyzw` - (`False`) homogenous coordinates. + (`False`) homogeneous coordinates. Returns ------- @@ -73,7 +81,7 @@ def Ry(angles_degrees: torch.Tensor, zyx: bool = False) -> torch.Tensor: def Rz(angles_degrees: torch.Tensor, zyx: bool = False) -> torch.Tensor: - """4x4 matrices for a rotation of homogenous coordinates around the Z-axis. + """4x4 matrices for a rotation of homogeneous coordinates around the Z-axis. Parameters ---------- @@ -81,7 +89,7 @@ def Rz(angles_degrees: torch.Tensor, zyx: bool = False) -> torch.Tensor: `(..., )` array of angles zyx: bool Whether output should be compatible with `zyxw` (`True`) or `xyzw` - (`False`) homogenous coordinates. + (`False`) homogeneous coordinates. Returns ------- @@ -152,7 +160,7 @@ def S(scale_factors: torch.Tensor) -> torch.Tensor: def R_2d(angles_degrees: torch.Tensor, yx: bool = False) -> torch.Tensor: - """3x3 matrices for a rotation of homogenous coordinates around the X-axis. + """3x3 matrices for a rotation of homogeneous coordinates around the X-axis. Parameters ---------- @@ -160,7 +168,7 @@ def R_2d(angles_degrees: torch.Tensor, yx: bool = False) -> torch.Tensor: `(..., )` array of angles yx: bool Whether output should be compatible with `yxw` (`True`) or `xyw` - (`False`) homogenous coordinates. + (`False`) homogeneous coordinates. Returns ------- @@ -225,3 +233,72 @@ def S_2d(scale_factors: torch.Tensor) -> torch.Tensor: matrices[:, [0, 1], [0, 1]] = scale_factors [matrices] = einops.unpack(matrices, packed_shapes=ps, pattern="* i j") return matrices + + +def stretch_matrix( + tilt_image_dimensions: Tuple[int, int], + tilt_axis_angle: torch.Tensor, + scale_factor: torch.Tensor, +) -> torch.Tensor: + """Calculate a tilt-image stretch matrix for coarse alignment.""" + image_center = torch.tensor(tilt_image_dimensions) // 2 + s0 = T_2d(-image_center) + r_forward = R_2d(tilt_axis_angle, yx=True) + r_backward = torch.linalg.inv(r_forward) + m_stretch = torch.eye(3) + m_stretch[1, 1] = scale_factor # this is a shear matrix + s1 = T_2d(image_center) + return s1 @ r_forward @ m_stretch @ r_backward @ s0 + + +def projection_model_to_projection_matrix( + projection_model: ProjectionModel, + tilt_image_dimensions: Tuple[int, int], + tomogram_dimensions: Tuple[int, int, int], +) -> torch.Tensor: + """Convert a cryotypes ProjectionModel to a projection matrix.""" + tilt_image_center = ( + torch.tensor((int(tomogram_dimensions[0]), *tilt_image_dimensions)) // 2 + ) + tomogram_center = torch.tensor(tomogram_dimensions) // 2 + s0 = T(-tilt_image_center) + r0 = Rx(torch.tensor(projection_model[PMDL.ROTATION_X].to_numpy()), zyx=True) + r1 = Ry(torch.tensor(projection_model[PMDL.ROTATION_Y].to_numpy()), zyx=True) + r2 = Rz(torch.tensor(projection_model[PMDL.ROTATION_Z].to_numpy()), zyx=True) + s1 = T( + F.pad( + torch.tensor(projection_model[PMDL.SHIFT].to_numpy()), pad=(1, 0), value=0 + ) + ) + s2 = T(tomogram_center) + return s2 @ s1 @ r2 @ r1 @ r0 @ s0 + + +def projection_model_to_tsa_matrix( + projection_model: ProjectionModel, + tilt_image_dimensions: Tuple[int, int], + projected_tomogram_dimensions: Tuple[int, int], +) -> torch.Tensor: + """Convert cryotypes ProjectionModel to a 2D tilt-series alignment matrix.""" + tilt_image_center = torch.tensor(tilt_image_dimensions) // 2 + projected_tomogram_center = torch.tensor(projected_tomogram_dimensions) // 2 + s0 = T_2d(-tilt_image_center) + r0 = R_2d(torch.tensor(projection_model[PMDL.ROTATION_Z].to_numpy()), yx=True) + s1 = T_2d(torch.tensor(projection_model[PMDL.SHIFT].to_numpy())) + s2 = T_2d(projected_tomogram_center) + # invert for forward alignment and reconstruction + return torch.linalg.inv(s2 @ s1 @ r0 @ s0) + + +def projection_model_to_backproject_matrix( + projection_model: ProjectionModel, + tomogram_dimensions: Tuple[int, int, int], +) -> torch.Tensor: + """Convert cryotypes ProjectionModel to a backprojection matrix.""" + tomogram_center = torch.tensor(tomogram_dimensions) // 2 + s0 = T(-tomogram_center) + r0 = Rx(torch.tensor(projection_model[PMDL.ROTATION_X].to_numpy()), zyx=True) + r1 = Ry(torch.tensor(projection_model[PMDL.ROTATION_Y].to_numpy()), zyx=True) + s1 = T(tomogram_center) + # invert for forward alignment and reconstruction + return torch.linalg.inv(s1 @ r1 @ r0 @ s0) diff --git a/src/tttsa/tttsa.py b/src/tttsa/tttsa.py index 9bed90a..00cd94f 100644 --- a/src/tttsa/tttsa.py +++ b/src/tttsa/tttsa.py @@ -1,8 +1,9 @@ """Main program for aligning tilt-series.""" -from typing import Tuple - +import numpy as np import torch +from cryotypes.projectionmodel import ProjectionModel +from cryotypes.projectionmodel import ProjectionModelDataLabels as PMDL from rich.console import Console from rich.progress import track from torch_fourier_shift import fourier_shift_image_2d @@ -12,6 +13,9 @@ from .projection_matching import projection_matching from .utils import circle +# update shift +PMDL.SHIFT = [PMDL.SHIFT_Y, PMDL.SHIFT_X] + # import logging # log = logging.getLogger(__name__) @@ -20,11 +24,10 @@ def tilt_series_alignment( tilt_series: torch.Tensor, - tilt_angle_priors: torch.Tensor, - tilt_axis_angle_prior: torch.Tensor, + projection_model_prior: ProjectionModel, alignment_z_height: int, find_tilt_angle_offset: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> ProjectionModel: """Align a tilt-series using AreTomo-style projection matching. AreTomo paper: @@ -41,7 +44,7 @@ def tilt_series_alignment( size = min(h, w) tomogram_dimensions = (alignment_z_height, size, size) tilt_dimensions = (size,) * 2 - reference_tilt = int(tilt_angle_priors.abs().argmin()) + reference_tilt = int(projection_model_prior[PMDL.ROTATION_Y].abs().argmin()) # mask for coarse alignment coarse_alignment_mask = circle( # ttmask -> tt-shapes; maybe add function @@ -53,12 +56,13 @@ def tilt_series_alignment( console.print("=== Starting teamtomo tilt-series alignment!", style="bold blue") + # make a copy of the ProjectionModel to store alignments in + projection_model = projection_model_prior.copy(deep=True) # do an IMOD style coarse tilt-series alignment - coarse_shifts = coarse_align(tilt_series, reference_tilt, coarse_alignment_mask) + projection_model[PMDL.SHIFT] = -coarse_align( + tilt_series, reference_tilt, coarse_alignment_mask + ).numpy() - tilt_axis_angles = tilt_axis_angle_prior.clone() - shifts = coarse_shifts.clone() - tilt_angles = tilt_angle_priors.clone() start_taa_grid_points = 1 # taa = tilt-axis angle pm_taa_grid_points = 3 # pm = projection matching @@ -66,24 +70,30 @@ def tilt_series_alignment( f"=== Optimizing tilt-axis angle with {start_taa_grid_points} grid point." ) for _ in track(range(3)): # optimize tilt axis angle - tilt_axis_angles = optimize_tilt_axis_angle( - fourier_shift_image_2d(tilt_series, shifts=shifts.to(device)), + projection_model[PMDL.ROTATION_Z] = optimize_tilt_axis_angle( + fourier_shift_image_2d( + tilt_series, + shifts=-torch.as_tensor( + projection_model[PMDL.SHIFT].to_numpy(), device=device + ), + ), coarse_alignment_mask, - tilt_axis_angles, + torch.as_tensor(projection_model[PMDL.ROTATION_Z]), grid_points=start_taa_grid_points, - ) + ).numpy() - shifts = stretch_align( + projection_model[PMDL.SHIFT] = -stretch_align( tilt_series, reference_tilt, coarse_alignment_mask, - tilt_angles, - tilt_axis_angles, - ) + torch.as_tensor(projection_model[PMDL.ROTATION_Y]), + torch.as_tensor(projection_model[PMDL.ROTATION_Z]), + ).numpy() console.print( - f"=== New tilt axis angle: {tilt_axis_angles.mean():.2f}° +-" - f" {tilt_axis_angles.std():.2f}°" + f"=== New tilt axis angle: " + f"{projection_model[PMDL.ROTATION_Z].mean():.2f}° +-" + f" {projection_model[PMDL.ROTATION_Z].std():.2f}°" ) if find_tilt_angle_offset: @@ -93,64 +103,69 @@ def tilt_series_alignment( tilt_angle_offset = optimize_tilt_angle_offset( tilt_series, coarse_alignment_mask, - tilt_angles, - tilt_axis_angles, - shifts, + torch.as_tensor(projection_model[PMDL.ROTATION_Y]), + torch.as_tensor(projection_model[PMDL.ROTATION_Z]), + torch.as_tensor(projection_model[PMDL.SHIFT].to_numpy()), ) full_offset += tilt_angle_offset.detach() - tilt_angles = tilt_angles + tilt_angle_offset.detach() - reference_tilt = int((tilt_angles).abs().argmin()) + projection_model[PMDL.ROTATION_Y] += float(tilt_angle_offset.detach()) + reference_tilt = int(projection_model[PMDL.ROTATION_Y].abs().argmin()) - shifts = stretch_align( + projection_model[PMDL.SHIFT] = -stretch_align( tilt_series, reference_tilt, coarse_alignment_mask, - tilt_angles, - tilt_axis_angles, - ) + torch.as_tensor(projection_model[PMDL.ROTATION_Y]), + torch.as_tensor(projection_model[PMDL.ROTATION_Z]), + ).numpy() console.print(f"=== Detected tilt-angle offset: {full_offset:.2f}°") # some optimizations parameters max_iter = 10 # this seems solid tolerance = 0.1 # should probably be related to pixel size - prev_shifts = shifts.clone() + prev_shifts = projection_model[PMDL.SHIFT].to_numpy() console.print( f"=== Starting projection matching with" f" {pm_taa_grid_points} grid points for the tilt-axis angle." ) for i in range(max_iter): - tilt_axis_angles = optimize_tilt_axis_angle( - fourier_shift_image_2d(tilt_series, shifts=prev_shifts.to(device)), + projection_model[PMDL.ROTATION_Z] = optimize_tilt_axis_angle( + fourier_shift_image_2d( + tilt_series, + shifts=-torch.as_tensor( + projection_model[PMDL.SHIFT].to_numpy(), device=device + ), + ), coarse_alignment_mask, - tilt_axis_angles, + torch.as_tensor(projection_model[PMDL.ROTATION_Z]), grid_points=pm_taa_grid_points, ) - shifts, _ = projection_matching( + projection_model, _ = projection_matching( tilt_series, - tomogram_dimensions, - reference_tilt, # REFERENCE_TILT, - tilt_angles, - tilt_axis_angles, - prev_shifts, + projection_model, + reference_tilt, coarse_alignment_mask, + tomogram_dimensions, ) + shifts = projection_model[PMDL.SHIFT].to_numpy() + abs_diff = np.abs(prev_shifts - shifts) console.print( f"--> Iteration {i + 1}, " f"sum of translation differences =" - f" {torch.abs(prev_shifts - shifts).sum():.2f}" + f" {abs_diff.sum():.2f}" ) - if torch.all(torch.abs(prev_shifts - shifts) < tolerance): + if np.all(abs_diff < tolerance): break prev_shifts = shifts console.print( - f"=== Final tilt-axis angle: {tilt_axis_angles.mean():.2f}° +-" - f" {tilt_axis_angles.std():.2f}°" + f"=== Final tilt-axis angle: {projection_model[PMDL.ROTATION_Z].mean():.2f}° +-" + f" {projection_model[PMDL.ROTATION_Z].std():.2f}°" ) console.print("===== Done!") - return tilt_angles, tilt_axis_angles, shifts + return projection_model diff --git a/tests/affine/test_affine_transform.py b/tests/affine/test_affine_transform.py index dba9495..4c7a586 100644 --- a/tests/affine/test_affine_transform.py +++ b/tests/affine/test_affine_transform.py @@ -1,16 +1,10 @@ import pytest import torch -from tttsa.affine import affine_transform_2d, affine_transform_3d, stretch_image +from tttsa.affine import affine_transform_2d, affine_transform_3d from tttsa.transformations import R_2d, Rz -def test_stretch_image(): - a = torch.zeros((5, 5)) - b = stretch_image(a, 1.1, -85) - assert a.shape == b.shape - - def test_affine_transform_2d(): m1 = R_2d(torch.tensor(45.0)) m2 = R_2d(torch.randn(3))