Skip to content

Commit

Permalink
add skeletonization code (#65)
Browse files Browse the repository at this point in the history
* add skeletonization code

* Second commit

* Second commit

* Second commit

* Second commit

* Third commit

* Third commit

* Fourth commit

* Fourth commit

* Fix data type warning and absolute value error

---------

Co-authored-by: Hanyi Zhang <[email protected]>
Co-authored-by: Hanyi Zhang <[email protected]>
  • Loading branch information
3 people authored Apr 28, 2024
1 parent 69b8585 commit 1317d16
Show file tree
Hide file tree
Showing 14 changed files with 653 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ repos:
args: [--fix]

- repo: https://github.com/psf/black
rev: 23.1.0
rev: 24.4.0
hooks:
- id: black

Expand Down
1 change: 1 addition & 0 deletions src/membrain_seg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""membrane segmentation in 3D for cryo-ET."""

from importlib.metadata import PackageNotFoundError, version

try:
Expand Down
1 change: 1 addition & 0 deletions src/membrain_seg/annotations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""empty init."""

from .cli import cli # noqa: F401
from .extract_patch_cli import extract_patches # noqa: F401
from .merge_corrections_cli import merge_corrections # noqa: F401
2 changes: 2 additions & 0 deletions src/membrain_seg/segmentation/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""CLI init function."""

# These imports are necessary to register CLI commands. Do not remove!
from .cli import cli # noqa: F401
from .segment_cli import segment # noqa: F401
from .ske_cli import skeletonize # noqa: F401
from .train_cli import data_dir_help, train # noqa: F401
64 changes: 64 additions & 0 deletions src/membrain_seg/segmentation/cli/ske_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import os

from typer import Option

from membrain_seg.segmentation.dataloading.data_utils import store_tomogram

from ..skeletonize import skeletonization as _skeletonization
from .cli import cli


@cli.command(name="skeletonize", no_args_is_help=True)
def skeletonize(
label_path: str = Option(..., help="Specifies the path for skeletonization."),
out_folder: str = Option(
"./predictions", help="Directory to save the resulting skeletons."
),
batch_size: int = Option(
None,
help="Optional batch size for processing the tomogram. If not specified, "
"the entire volume is processed at once. If operating with limited GPU "
"resources, a batch size of 1,000,000 is recommended.",
),
):
"""
Perform skeletonization on labeled tomograms using nonmax-suppression technique.
This function reads a labeled tomogram, applies skeletonization using a specified
batch size, and stores the results in an MRC file in the specified output directory.
If batch_size is set to None, the entire tomogram is processed at once, which might
require significant memory. It is recommended to specify a batch size if memory
constraints are a concern. The maximum possible batch size is the product of the
tomogram's dimensions (Nx * Ny * Nz).
Parameters
----------
label_path : str
File path to the tomogram to be skeletonized.
out_folder : str
Output folder path for the skeletonized tomogram.
batch_size : int, optional
The size of the batch to process the tomogram. Defaults to None, which processes
the entire volume at once. For large volumes, consider setting it to a specific
value like 1,000,000 for efficient processing without exceeding memory limits.
Examples
--------
membrain skeletonize --label-path <path> --out-folder <output-directory>
--batch-size <batch-size>
"""
# Assuming _skeletonization function is already defined and can handle batch_size
ske = _skeletonization(label_path=label_path, batch_size=batch_size)

if not os.path.exists(out_folder):
os.makedirs(out_folder)

out_file = os.path.join(
out_folder,
os.path.splitext(os.path.basename(label_path))[0] + "_skel.mrc",
)

store_tomogram(filename=out_file, tomogram=ske)
print("Skeleton saved to ", out_file)
16 changes: 10 additions & 6 deletions src/membrain_seg/segmentation/dataloading/memseg_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,11 @@ def get_training_transforms(
np.random.uniform(np.log(x[y] // 6), np.log(x[y]))
),
loc=(-0.5, 1.5),
max_strength=lambda x, y: np.random.uniform(-5, -1)
if np.random.uniform() < 0.5
else np.random.uniform(1, 5),
max_strength=lambda x, y: (
np.random.uniform(-5, -1)
if np.random.uniform() < 0.5
else np.random.uniform(1, 5)
),
mean_centered=False,
),
prob=(1.0 if prob_to_one else 0.3),
Expand All @@ -268,9 +270,11 @@ def get_training_transforms(
np.random.uniform(np.log(x[y] // 6), np.log(x[y]))
),
loc=(-0.5, 1.5),
gamma=lambda: np.random.uniform(0.01, 0.8)
if np.random.uniform() < 0.5
else np.random.uniform(1.5, 4),
gamma=lambda: (
np.random.uniform(0.01, 0.8)
if np.random.uniform() < 0.5
else np.random.uniform(1.5, 4)
),
),
prob=(1.0 if prob_to_one else 0.3),
),
Expand Down
6 changes: 3 additions & 3 deletions src/membrain_seg/segmentation/dataloading/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,9 @@ def __call__(self, data):
y = self.R.randint(0, y_max - height)
x = self.R.randint(0, x_max - width)
if self.replace_with == "mean":
image[
..., z : z + depth, y : y + height, x : x + width
] = torch.mean(torch.Tensor(image))
image[..., z : z + depth, y : y + height, x : x + width] = (
torch.mean(torch.Tensor(image))
)
elif self.replace_with == 0.0:
image[..., z : z + depth, y : y + height, x : x + width] = 0.0
d[key] = image
Expand Down
1 change: 1 addition & 0 deletions src/membrain_seg/segmentation/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Neural networks implemented as pytorch lightning modules."""

__all__ = ["SemanticSegmentationUnet"]
from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet
14 changes: 8 additions & 6 deletions src/membrain_seg/segmentation/networks/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,11 @@ def __init__(

self.loss_function = DeepSuperVisionLoss(
loss_function,
weights=[1.0, 0.5, 0.25, 0.125, 0.0675]
if use_deep_supervision
else [1.0, 0.0, 0.0, 0.0, 0.0],
weights=(
[1.0, 0.5, 0.25, 0.125, 0.0675]
if use_deep_supervision
else [1.0, 0.0, 0.0, 0.0, 0.0]
),
)

# validation metric
Expand Down Expand Up @@ -278,9 +280,9 @@ def validation_step(self, batch, batch_idx):
# compute more stats?
outputs4dice = outputs[0].clone()
labels4dice = labels[0].clone()
outputs4dice[
labels4dice == 2
] = -1.0 # Setting to -1 here leads to 0-labels after thresholding
outputs4dice[labels4dice == 2] = (
-1.0
) # Setting to -1 here leads to 0-labels after thresholding
labels4dice[labels4dice == 2] = 0 # Need to set to zero before post_label
# Otherwise we have 3 classes
outputs4dice = [self.post_pred(i) for i in decollate_batch(outputs4dice)]
Expand Down
Empty file.
147 changes: 147 additions & 0 deletions src/membrain_seg/segmentation/skeletonization/diff3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# ---------------------------------------------------------------------------------
# DISCLAIMER: This code is adapted from the MATLAB and C++ implementations provided
# in the paper titled "Robust membrane detection based on tensor voting for electron
# tomography" by Antonio Martinez-Sanchez, Inmaculada Garcia, Shoh Asano, Vladan Lucic,
# and Jose-Jesus Fernandez, published in the Journal of Structural Biology,
# Volume 186, Issue 1, 2014, Pages 49-61. The original work can be accessed via
# https://doi.org/10.1016/j.jsb.2014.02.015 and is used under conditions that adhere
# to the original licensing agreements. For details on the original license, refer to
# the publication: https://www.sciencedirect.com/science/article/pii/S1047847714000495.
# ---------------------------------------------------------------------------------
import numpy as np


def calculate_derivative_3d(tomogram: np.ndarray, axis: int) -> np.ndarray:
"""
Calculate the partial derivative of a 3D tomogram along a specified dimension.
Parameters
----------
tomogram : np.ndarray
The input 3D tomogram as a numpy array, where each dimension
corresponds to spatial dimensions.
axis : int
The axis along which to compute the derivative.
Set axis=0 for the x-dimension, axis=1 for the y-dimension,
and any other value for the z-dimension.
Returns
-------
np.ndarray
The output tomogram,
which represents the partial derivatives along the specified axis.
This output has the same shape as the input array.
Notes
-----
The function computes the centered difference in the specified dimension.
The boundaries are handled by padding the last slice with the value from
the second to last slice, ensuring smooth derivative values at the edges
of the tomogram.
Examples
--------
Create a sample 3D array and compute the partial derivative
along the x-axis (axis=0):
>>> tomogram = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
>>> calculate_derivative_3d(tomogram, 0)
array([[[ 4., 4.],
[ 4., 4.]],
[[ 0., 0.],
[ 0., 0.]]])
"""
# Get the size of the input tomogram
num_x, num_y, num_z = tomogram.shape

# Initialize arrays for forward and backward differences
forward_difference = np.zeros((num_x, num_y, num_z), dtype="float32")
backward_difference = np.zeros((num_x, num_y, num_z), dtype="float32")

# Calculate partial derivatives along the specified dimension (axis)
if axis == 0:
forward_difference[0 : num_x - 1, :, :] = tomogram[1:num_x, :, :]
backward_difference[1:num_x, :, :] = tomogram[0 : num_x - 1, :, :]
# Pad extremes
forward_difference[num_x - 1, :, :] = forward_difference[num_x - 2, :, :]
backward_difference[0, :, :] = backward_difference[1, :, :]
elif axis == 1:
forward_difference[:, 0 : num_y - 1, :] = tomogram[:, 1:num_y, :]
backward_difference[:, 1:num_y, :] = tomogram[:, 0 : num_y - 1, :]
# Pad extremes
forward_difference[:, num_y - 1, :] = forward_difference[:, num_y - 2, :]
backward_difference[:, 0, :] = backward_difference[:, 1, :]
else:
forward_difference[:, :, 0 : num_z - 1] = tomogram[:, :, 1:num_z]
backward_difference[:, :, 1:num_z] = tomogram[:, :, 0 : num_z - 1]
# Pad extremes
forward_difference[:, :, num_z - 1] = forward_difference[:, :, num_z - 2]
backward_difference[:, :, 0] = backward_difference[:, :, 1]

# Calculate the output tomogram
derivative_tomogram = (forward_difference - backward_difference) * 0.5

return derivative_tomogram


def compute_gradients(tomogram: np.ndarray) -> tuple:
"""
Computes the gradients along each spatial dimension of a 3D tomogram.
Parameters
----------
tomogram : np.ndarray
The input 3D tomogram as a numpy array.
Returns
-------
tuple
A tuple containing the gradient components (gradX, gradY, gradZ).
Notes
-----
This function calculates the partial derivatives of the tomogram along the x, y,
and z dimensions, respectively. These derivatives represent the gradient components
along each dimension.
"""
gradX = calculate_derivative_3d(tomogram, 0)
gradY = calculate_derivative_3d(tomogram, 1)
gradZ = calculate_derivative_3d(tomogram, 2)

return gradX, gradY, gradZ


def compute_hessian(gradX: np.ndarray, gradY: np.ndarray, gradZ: np.ndarray) -> tuple:
"""
Computes the Hessian tensor components for a 3D tomogram from its gradients.
Parameters
----------
gradX : np.ndarray
Gradient of the tomogram along the x-axis.
gradY : np.ndarray
Gradient of the tomogram along the y-axis.
gradZ : np.ndarray
Gradient of the tomogram along the z-axis.
Returns
-------
tuple
A tuple containing the Hessian tensor components (hessianXX, hessianYY,
hessianZZ, hessianXY, hessianXZ, hessianYZ).
Notes
-----
This function computes the second derivatives of the tomogram along each dimension.
These second derivatives form the components of the Hessian tensor, providing
information about the curvature of the tomogram.
"""
hessianXX = calculate_derivative_3d(gradX, 0)
hessianYY = calculate_derivative_3d(gradY, 1)
hessianZZ = calculate_derivative_3d(gradZ, 2)
hessianXY = calculate_derivative_3d(gradX, 1)
hessianXZ = calculate_derivative_3d(gradX, 2)
hessianYZ = calculate_derivative_3d(gradY, 2)

return hessianXX, hessianYY, hessianZZ, hessianXY, hessianXZ, hessianYZ
Loading

0 comments on commit 1317d16

Please sign in to comment.