Skip to content

Commit

Permalink
feat: add rich progress bars and printing
Browse files Browse the repository at this point in the history
  • Loading branch information
McHaillet committed Nov 8, 2024
1 parent 8bb7d52 commit fe92c9a
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 34 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies = [
"einops",
"numpy",
"scipy",
"rich", # https://github.com/Textualize/rich
]

# https://peps.python.org/pep-0621/#dependencies-optional-dependencies
Expand All @@ -55,7 +56,6 @@ dev = [
"mypy",
"pdbpp", # https://github.com/pdbpp/pdbpp
"pre-commit",
"rich", # https://github.com/Textualize/rich
"ruff",
"mrcfile",
"torch-fourier-rescale",
Expand Down
8 changes: 4 additions & 4 deletions src/tttsa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
__author__ = "Marten Chaillet"
__email__ = "[email protected]"

import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# import logging
# import sys
#
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
2 changes: 0 additions & 2 deletions src/tttsa/coarse_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def coarse_align(
mask: torch.Tensor,
) -> torch.Tensor:
"""Find coarse shifts of images without stretching along tilt axis."""
print("finding translational alignment using pure cross-correlation")
shifts = torch.zeros((len(tilt_series), 2), dtype=torch.float32)
# find coarse alignment for negative tilts
current_shift = torch.zeros(2)
Expand Down Expand Up @@ -41,7 +40,6 @@ def stretch_align(
tilt_axis_angles: torch.Tensor,
) -> torch.Tensor:
"""Find coarse shifts of images while stretching each pair along the tilt axis."""
print("finding translational alignment using pairwise stretching")
shifts = torch.zeros((len(tilt_series), 2), dtype=torch.float32)
# find coarse alignment for negative tilts
current_shift = torch.zeros(2)
Expand Down
5 changes: 0 additions & 5 deletions src/tttsa/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,6 @@ def optimize_tilt_axis_angle(
grid_points: int = 1,
) -> torch.Tensor:
"""Optimize tilt axis angles on a spline grid using the LBFGS optimizer."""
print(
f"optimizing the tilt-axis angle (in-plane rotations) with {grid_points} "
f"grid point(s)"
)

coarse_aligned_masked = aligned_ts * coarse_alignment_mask

# generate a weighting for the common line ROI by projecting the mask
Expand Down
7 changes: 2 additions & 5 deletions src/tttsa/projection_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import einops
import torch
from rich.progress import track

from .alignment import find_image_shift
from .back_projection import filtered_back_projection_3d
Expand Down Expand Up @@ -39,7 +40,7 @@ def projection_matching(
projections = torch.zeros((n_tilts, size, size))
projections[reference_tilt_id] = tilt_series[reference_tilt_id]

for i in index_sequence:
for i in track(index_sequence):
tilt_angle = tilt_angles[i]
weights = einops.rearrange(
torch.cos(torch.deg2rad(torch.abs(tilt_angles - tilt_angle))),
Expand Down Expand Up @@ -79,8 +80,4 @@ def projection_matching(
# for debugging:
projections[i] = projection.detach().to("cpu")

print( # TODO should be some sort of logging?
f"aligned index {i} at angle {tilt_angle:.2f}: {shift}"
)

return shifts, projections
58 changes: 41 additions & 17 deletions src/tttsa/tttsa.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
"""Main program for aligning tilt-series."""

# set logger => TODO use rich instead
import logging
from typing import Tuple

import torch
from rich.console import Console
from rich.progress import track
from torch_fourier_shift import fourier_shift_image_2d

from .coarse_align import coarse_align, stretch_align
from .optimizers import optimize_tilt_angle_offset, optimize_tilt_axis_angle
from .projection_matching import projection_matching
from .utils import circle

log = logging.getLogger(__name__)
# import logging
# log = logging.getLogger(__name__)

console = Console()


def tilt_series_alignment(
Expand Down Expand Up @@ -48,22 +51,23 @@ def tilt_series_alignment(

# do an IMOD style coarse tilt-series alignment
coarse_shifts = coarse_align(tilt_series, reference_tilt, coarse_alignment_mask)
coarse_aligned = fourier_shift_image_2d(tilt_series, shifts=coarse_shifts)

tilt_axis_angles = torch.tensor(tilt_axis_angle_prior)
shifts = coarse_shifts.clone()
tilt_angles = tilt_angle_priors.clone()

for _ in range(3): # optimize tilt axis angle
start_taa_grid_points = 1 # taa = tilt-axis angle
pm_taa_grid_points = 3 # pm = projection matching
console.print("=== Starting teamtomo tilt-series alignment!", style="bold blue")
console.print(
f"=== Optimizing tilt-axis angle with {start_taa_grid_points} grid point."
)
for _ in track(range(3)): # optimize tilt axis angle
tilt_axis_angles = optimize_tilt_axis_angle(
coarse_aligned,
fourier_shift_image_2d(tilt_series, shifts=shifts),
coarse_alignment_mask,
tilt_axis_angles,
grid_points=start_taa_grid_points,
)
print(
f"new tilt axis angle: {tilt_axis_angles.mean():.2f} +-"
f" {tilt_axis_angles.std():.2f}"
) # use rich logging?

shifts = stretch_align(
tilt_series,
Expand All @@ -73,18 +77,23 @@ def tilt_series_alignment(
tilt_axis_angles,
)

coarse_aligned = fourier_shift_image_2d(tilt_series, shifts=shifts)
console.print(
f"=== New tilt axis angle: {tilt_axis_angles.mean():.2f}° +-"
f" {tilt_axis_angles.std():.2f}°"
)

if find_tilt_angle_offset:
for _ in range(3):
full_offset = torch.tensor(0.0)
console.print("=== Optimizing tilt-angle offset.")
for _ in track(range(3)):
tilt_angle_offset = optimize_tilt_angle_offset(
tilt_series,
coarse_alignment_mask,
tilt_angles,
tilt_axis_angles,
shifts,
)
print(f"detected tilt angle offset: {tilt_angle_offset}")
full_offset += tilt_angle_offset.detach()
tilt_angles = tilt_angles + tilt_angle_offset.detach()
reference_tilt = int((tilt_angles).abs().argmin())

Expand All @@ -95,20 +104,23 @@ def tilt_series_alignment(
tilt_angles,
tilt_axis_angles,
)
console.print(f"=== Detected tilt-angle offset: {full_offset:.2f}°")

# some optimizations parameters
max_iter = 10 # this seems solid
tolerance = 0.1 # should probably be related to pixel size
predicted_tilts = []
console.print(
f"=== Starting projection matching with"
f" {pm_taa_grid_points} grid points for the tilt-axis angle."
)
for i in range(max_iter):
print(f"projection matching iteration {i}")
tilt_axis_angles = optimize_tilt_axis_angle(
fourier_shift_image_2d(tilt_series, shifts=shifts),
coarse_alignment_mask,
tilt_axis_angles,
grid_points=3,
grid_points=pm_taa_grid_points,
)
print("new tilt axis angle:", tilt_axis_angles)

new_shifts, pred = projection_matching(
tilt_series,
Expand All @@ -121,9 +133,21 @@ def tilt_series_alignment(
)
predicted_tilts.append(pred)

console.print(
f"--> Iteration {i + 1}, "
f"sum of translation differences ="
f" {torch.abs(shifts - new_shifts).sum():.2f}"
)

if torch.all(torch.abs(shifts - new_shifts) < tolerance):
break

shifts = new_shifts

console.print(
f"=== Final tilt-axis angle: {tilt_axis_angles.mean():.2f}° +-"
f" {tilt_axis_angles.std():.2f}°"
)
console.print("===== Done!")

return tilt_angles, tilt_axis_angles, shifts

0 comments on commit fe92c9a

Please sign in to comment.