diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7b336c2..dd27650 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: args: [--fix] - repo: https://github.com/psf/black - rev: 23.1.0 + rev: 24.4.0 hooks: - id: black diff --git a/src/membrain_seg/__init__.py b/src/membrain_seg/__init__.py index 8b09afd..2513a51 100644 --- a/src/membrain_seg/__init__.py +++ b/src/membrain_seg/__init__.py @@ -1,4 +1,5 @@ """membrane segmentation in 3D for cryo-ET.""" + from importlib.metadata import PackageNotFoundError, version try: diff --git a/src/membrain_seg/annotations/__init__.py b/src/membrain_seg/annotations/__init__.py index 8ae0cfd..2112f97 100644 --- a/src/membrain_seg/annotations/__init__.py +++ b/src/membrain_seg/annotations/__init__.py @@ -1,4 +1,5 @@ """empty init.""" + from .cli import cli # noqa: F401 from .extract_patch_cli import extract_patches # noqa: F401 from .merge_corrections_cli import merge_corrections # noqa: F401 diff --git a/src/membrain_seg/segmentation/cli/__init__.py b/src/membrain_seg/segmentation/cli/__init__.py index ebbaa73..a2c078c 100644 --- a/src/membrain_seg/segmentation/cli/__init__.py +++ b/src/membrain_seg/segmentation/cli/__init__.py @@ -1,5 +1,7 @@ """CLI init function.""" + # These imports are necessary to register CLI commands. Do not remove! from .cli import cli # noqa: F401 from .segment_cli import segment # noqa: F401 +from .ske_cli import skeletonize # noqa: F401 from .train_cli import data_dir_help, train # noqa: F401 diff --git a/src/membrain_seg/segmentation/cli/ske_cli.py b/src/membrain_seg/segmentation/cli/ske_cli.py new file mode 100644 index 0000000..a4fdcfe --- /dev/null +++ b/src/membrain_seg/segmentation/cli/ske_cli.py @@ -0,0 +1,64 @@ +import os + +from typer import Option + +from membrain_seg.segmentation.dataloading.data_utils import store_tomogram + +from ..skeletonize import skeletonization as _skeletonization +from .cli import cli + + +@cli.command(name="skeletonize", no_args_is_help=True) +def skeletonize( + label_path: str = Option(..., help="Specifies the path for skeletonization."), + out_folder: str = Option( + "./predictions", help="Directory to save the resulting skeletons." + ), + batch_size: int = Option( + None, + help="Optional batch size for processing the tomogram. If not specified, " + "the entire volume is processed at once. If operating with limited GPU " + "resources, a batch size of 1,000,000 is recommended.", + ), +): + """ + Perform skeletonization on labeled tomograms using nonmax-suppression technique. + + This function reads a labeled tomogram, applies skeletonization using a specified + batch size, and stores the results in an MRC file in the specified output directory. + If batch_size is set to None, the entire tomogram is processed at once, which might + require significant memory. It is recommended to specify a batch size if memory + constraints are a concern. The maximum possible batch size is the product of the + tomogram's dimensions (Nx * Ny * Nz). + + + Parameters + ---------- + label_path : str + File path to the tomogram to be skeletonized. + out_folder : str + Output folder path for the skeletonized tomogram. + batch_size : int, optional + The size of the batch to process the tomogram. Defaults to None, which processes + the entire volume at once. For large volumes, consider setting it to a specific + value like 1,000,000 for efficient processing without exceeding memory limits. 
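As a quick illustration of the batch-size guidance above (not part of the PR code): the largest admissible value is simply the voxel count of the volume, so a hypothetical helper such as `suggest_batch_size` below shows how the recommended cap of 1,000,000 relates to a typical tomogram shape.

```python
import numpy as np


def suggest_batch_size(shape, cap=1_000_000):
    """Hypothetical helper: cap the eigendecomposition batch size for limited GPUs."""
    total_voxels = int(np.prod(shape))  # maximum possible batch size (Nx * Ny * Nz)
    return min(total_voxels, cap)


# A 928 x 928 x 300 tomogram has ~2.6e8 voxels, so the capped value is returned.
print(suggest_batch_size((928, 928, 300)))  # 1000000
```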
+ + + Examples + -------- + membrain skeletonize --label-path --out-folder + --batch-size + """ + # Assuming _skeletonization function is already defined and can handle batch_size + ske = _skeletonization(label_path=label_path, batch_size=batch_size) + + if not os.path.exists(out_folder): + os.makedirs(out_folder) + + out_file = os.path.join( + out_folder, + os.path.splitext(os.path.basename(label_path))[0] + "_skel.mrc", + ) + + store_tomogram(filename=out_file, tomogram=ske) + print("Skeleton saved to ", out_file) diff --git a/src/membrain_seg/segmentation/dataloading/memseg_augmentation.py b/src/membrain_seg/segmentation/dataloading/memseg_augmentation.py index 96a3091..01f55b1 100644 --- a/src/membrain_seg/segmentation/dataloading/memseg_augmentation.py +++ b/src/membrain_seg/segmentation/dataloading/memseg_augmentation.py @@ -254,9 +254,11 @@ def get_training_transforms( np.random.uniform(np.log(x[y] // 6), np.log(x[y])) ), loc=(-0.5, 1.5), - max_strength=lambda x, y: np.random.uniform(-5, -1) - if np.random.uniform() < 0.5 - else np.random.uniform(1, 5), + max_strength=lambda x, y: ( + np.random.uniform(-5, -1) + if np.random.uniform() < 0.5 + else np.random.uniform(1, 5) + ), mean_centered=False, ), prob=(1.0 if prob_to_one else 0.3), @@ -268,9 +270,11 @@ def get_training_transforms( np.random.uniform(np.log(x[y] // 6), np.log(x[y])) ), loc=(-0.5, 1.5), - gamma=lambda: np.random.uniform(0.01, 0.8) - if np.random.uniform() < 0.5 - else np.random.uniform(1.5, 4), + gamma=lambda: ( + np.random.uniform(0.01, 0.8) + if np.random.uniform() < 0.5 + else np.random.uniform(1.5, 4) + ), ), prob=(1.0 if prob_to_one else 0.3), ), diff --git a/src/membrain_seg/segmentation/dataloading/transforms.py b/src/membrain_seg/segmentation/dataloading/transforms.py index 2c31bd9..71da1c6 100644 --- a/src/membrain_seg/segmentation/dataloading/transforms.py +++ b/src/membrain_seg/segmentation/dataloading/transforms.py @@ -288,9 +288,9 @@ def __call__(self, data): y = self.R.randint(0, y_max - height) x = self.R.randint(0, x_max - width) if self.replace_with == "mean": - image[ - ..., z : z + depth, y : y + height, x : x + width - ] = torch.mean(torch.Tensor(image)) + image[..., z : z + depth, y : y + height, x : x + width] = ( + torch.mean(torch.Tensor(image)) + ) elif self.replace_with == 0.0: image[..., z : z + depth, y : y + height, x : x + width] = 0.0 d[key] = image diff --git a/src/membrain_seg/segmentation/networks/__init__.py b/src/membrain_seg/segmentation/networks/__init__.py index f15810c..3968420 100644 --- a/src/membrain_seg/segmentation/networks/__init__.py +++ b/src/membrain_seg/segmentation/networks/__init__.py @@ -1,3 +1,4 @@ """Neural networks implemented as pytorch lightning modules.""" + __all__ = ["SemanticSegmentationUnet"] from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet diff --git a/src/membrain_seg/segmentation/networks/unet.py b/src/membrain_seg/segmentation/networks/unet.py index 711f1d1..a387121 100644 --- a/src/membrain_seg/segmentation/networks/unet.py +++ b/src/membrain_seg/segmentation/networks/unet.py @@ -140,9 +140,11 @@ def __init__( self.loss_function = DeepSuperVisionLoss( loss_function, - weights=[1.0, 0.5, 0.25, 0.125, 0.0675] - if use_deep_supervision - else [1.0, 0.0, 0.0, 0.0, 0.0], + weights=( + [1.0, 0.5, 0.25, 0.125, 0.0675] + if use_deep_supervision + else [1.0, 0.0, 0.0, 0.0, 0.0] + ), ) # validation metric @@ -278,9 +280,9 @@ def validation_step(self, batch, batch_idx): # compute more stats? 
outputs4dice = outputs[0].clone() labels4dice = labels[0].clone() - outputs4dice[ - labels4dice == 2 - ] = -1.0 # Setting to -1 here leads to 0-labels after thresholding + outputs4dice[labels4dice == 2] = ( + -1.0 + ) # Setting to -1 here leads to 0-labels after thresholding labels4dice[labels4dice == 2] = 0 # Need to set to zero before post_label # Otherwise we have 3 classes outputs4dice = [self.post_pred(i) for i in decollate_batch(outputs4dice)] diff --git a/src/membrain_seg/segmentation/skeletonization/__init__.py b/src/membrain_seg/segmentation/skeletonization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/membrain_seg/segmentation/skeletonization/diff3d.py b/src/membrain_seg/segmentation/skeletonization/diff3d.py new file mode 100644 index 0000000..a90ec97 --- /dev/null +++ b/src/membrain_seg/segmentation/skeletonization/diff3d.py @@ -0,0 +1,147 @@ +# --------------------------------------------------------------------------------- +# DISCLAIMER: This code is adapted from the MATLAB and C++ implementations provided +# in the paper titled "Robust membrane detection based on tensor voting for electron +# tomography" by Antonio Martinez-Sanchez, Inmaculada Garcia, Shoh Asano, Vladan Lucic, +# and Jose-Jesus Fernandez, published in the Journal of Structural Biology, +# Volume 186, Issue 1, 2014, Pages 49-61. The original work can be accessed via +# https://doi.org/10.1016/j.jsb.2014.02.015 and is used under conditions that adhere +# to the original licensing agreements. For details on the original license, refer to +# the publication: https://www.sciencedirect.com/science/article/pii/S1047847714000495. +# --------------------------------------------------------------------------------- +import numpy as np + + +def calculate_derivative_3d(tomogram: np.ndarray, axis: int) -> np.ndarray: + """ + Calculate the partial derivative of a 3D tomogram along a specified dimension. + + Parameters + ---------- + tomogram : np.ndarray + The input 3D tomogram as a numpy array, where each dimension + corresponds to spatial dimensions. + axis : int + The axis along which to compute the derivative. + Set axis=0 for the x-dimension, axis=1 for the y-dimension, + and any other value for the z-dimension. + + Returns + ------- + np.ndarray + The output tomogram, + which represents the partial derivatives along the specified axis. + This output has the same shape as the input array. + + Notes + ----- + The function computes the centered difference in the specified dimension. + The boundaries are handled by padding the last slice with the value from + the second to last slice, ensuring smooth derivative values at the edges + of the tomogram. 
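A small sanity check, not part of this PR, connects the scheme described in the Notes to a familiar reference: away from the two padded boundary slices, the centred difference computed here should agree with `np.gradient` along the same axis (the import path is the one added by this diff).

```python
import numpy as np

from membrain_seg.segmentation.skeletonization.diff3d import calculate_derivative_3d

rng = np.random.default_rng(0)
vol = rng.random((8, 8, 8)).astype("float32")

# Interior voxels use (f[i+1] - f[i-1]) / 2, exactly like np.gradient away from edges;
# only the first and last slices differ because of the padding described in the Notes.
d0 = calculate_derivative_3d(vol, 0)
assert np.allclose(d0[1:-1], np.gradient(vol, axis=0)[1:-1], atol=1e-6)
```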
+
+    Examples
+    --------
+    Create a sample 3D array and compute the partial derivative
+    along the x-axis (axis=0):
+
+    >>> tomogram = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+    >>> calculate_derivative_3d(tomogram, 0)
+    array([[[2., 2.],
+            [2., 2.]],
+
+           [[2., 2.],
+            [2., 2.]]], dtype=float32)
+    """
+    # Get the size of the input tomogram
+    num_x, num_y, num_z = tomogram.shape
+
+    # Initialize arrays for forward and backward differences
+    forward_difference = np.zeros((num_x, num_y, num_z), dtype="float32")
+    backward_difference = np.zeros((num_x, num_y, num_z), dtype="float32")
+
+    # Calculate partial derivatives along the specified dimension (axis)
+    if axis == 0:
+        forward_difference[0 : num_x - 1, :, :] = tomogram[1:num_x, :, :]
+        backward_difference[1:num_x, :, :] = tomogram[0 : num_x - 1, :, :]
+        # Pad extremes
+        forward_difference[num_x - 1, :, :] = forward_difference[num_x - 2, :, :]
+        backward_difference[0, :, :] = backward_difference[1, :, :]
+    elif axis == 1:
+        forward_difference[:, 0 : num_y - 1, :] = tomogram[:, 1:num_y, :]
+        backward_difference[:, 1:num_y, :] = tomogram[:, 0 : num_y - 1, :]
+        # Pad extremes
+        forward_difference[:, num_y - 1, :] = forward_difference[:, num_y - 2, :]
+        backward_difference[:, 0, :] = backward_difference[:, 1, :]
+    else:
+        forward_difference[:, :, 0 : num_z - 1] = tomogram[:, :, 1:num_z]
+        backward_difference[:, :, 1:num_z] = tomogram[:, :, 0 : num_z - 1]
+        # Pad extremes
+        forward_difference[:, :, num_z - 1] = forward_difference[:, :, num_z - 2]
+        backward_difference[:, :, 0] = backward_difference[:, :, 1]
+
+    # Calculate the output tomogram
+    derivative_tomogram = (forward_difference - backward_difference) * 0.5
+
+    return derivative_tomogram
+
+
+def compute_gradients(tomogram: np.ndarray) -> tuple:
+    """
+    Computes the gradients along each spatial dimension of a 3D tomogram.
+
+    Parameters
+    ----------
+    tomogram : np.ndarray
+        The input 3D tomogram as a numpy array.
+
+    Returns
+    -------
+    tuple
+        A tuple containing the gradient components (gradX, gradY, gradZ).
+
+    Notes
+    -----
+    This function calculates the partial derivatives of the tomogram along the x, y,
+    and z dimensions, respectively. These derivatives represent the gradient components
+    along each dimension.
+    """
+    gradX = calculate_derivative_3d(tomogram, 0)
+    gradY = calculate_derivative_3d(tomogram, 1)
+    gradZ = calculate_derivative_3d(tomogram, 2)
+
+    return gradX, gradY, gradZ
+
+
+def compute_hessian(gradX: np.ndarray, gradY: np.ndarray, gradZ: np.ndarray) -> tuple:
+    """
+    Computes the Hessian tensor components for a 3D tomogram from its gradients.
+
+    Parameters
+    ----------
+    gradX : np.ndarray
+        Gradient of the tomogram along the x-axis.
+    gradY : np.ndarray
+        Gradient of the tomogram along the y-axis.
+    gradZ : np.ndarray
+        Gradient of the tomogram along the z-axis.
+
+    Returns
+    -------
+    tuple
+        A tuple containing the Hessian tensor components (hessianXX, hessianYY,
+        hessianZZ, hessianXY, hessianXZ, hessianYZ).
+
+    Notes
+    -----
+    This function computes the second derivatives of the tomogram along each dimension.
+    These second derivatives form the components of the Hessian tensor, providing
+    information about the curvature of the tomogram.
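To make the role of the six returned components concrete, here is an illustrative snippet (not package code) that assembles the symmetric 3x3 Hessian at a single voxel; this per-voxel matrix is exactly what the eigen-analysis in `eig3d.py` operates on.

```python
import numpy as np

from membrain_seg.segmentation.skeletonization.diff3d import (
    compute_gradients,
    compute_hessian,
)

vol = np.random.default_rng(1).random((16, 16, 16)).astype("float32")
hxx, hyy, hzz, hxy, hxz, hyz = compute_hessian(*compute_gradients(vol))

# At any voxel (i, j, k) the six components define a symmetric 3x3 matrix.
i, j, k = 8, 8, 8
H = np.array(
    [
        [hxx[i, j, k], hxy[i, j, k], hxz[i, j, k]],
        [hxy[i, j, k], hyy[i, j, k], hyz[i, j, k]],
        [hxz[i, j, k], hyz[i, j, k], hzz[i, j, k]],
    ]
)
eigenvalues = np.linalg.eigvalsh(H)  # real eigenvalues, since H is symmetric
print(eigenvalues[np.argmax(np.abs(eigenvalues))])  # dominant (largest-magnitude) one
```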
+ """ + hessianXX = calculate_derivative_3d(gradX, 0) + hessianYY = calculate_derivative_3d(gradY, 1) + hessianZZ = calculate_derivative_3d(gradZ, 2) + hessianXY = calculate_derivative_3d(gradX, 1) + hessianXZ = calculate_derivative_3d(gradX, 2) + hessianYZ = calculate_derivative_3d(gradY, 2) + + return hessianXX, hessianYY, hessianZZ, hessianXY, hessianXZ, hessianYZ diff --git a/src/membrain_seg/segmentation/skeletonization/eig3d.py b/src/membrain_seg/segmentation/skeletonization/eig3d.py new file mode 100644 index 0000000..91eb8fd --- /dev/null +++ b/src/membrain_seg/segmentation/skeletonization/eig3d.py @@ -0,0 +1,116 @@ +# --------------------------------------------------------------------------------- +# DISCLAIMER: This code is adapted from the MATLAB and C++ implementations provided +# in the paper titled "Robust membrane detection based on tensor voting for electron +# tomography" by Antonio Martinez-Sanchez, Inmaculada Garcia, Shoh Asano, Vladan Lucic, +# and Jose-Jesus Fernandez, published in the Journal of Structural Biology, +# Volume 186, Issue 1, 2014, Pages 49-61. The original work can be accessed via +# https://doi.org/10.1016/j.jsb.2014.02.015 and is used under conditions that adhere +# to the original licensing agreements. For details on the original license, refer to +# the publication: https://www.sciencedirect.com/science/article/pii/S1047847714000495. +# --------------------------------------------------------------------------------- +from typing import List, Tuple + +import numpy as np +import torch + + +def batch_mask_eigendecomposition_3d( + filtered_hessian: List[torch.Tensor], batch_size: int, labels: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + """ + Perform batch eigendecomposition on a 3D Hessian matrix using a binary mask to + select voxels. + + This function processes only those voxels where the label is set to 1, + computing the largest eigenvalue and its corresponding eigenvector for + each selected voxel. It handles large 3D datasets efficiently by performing + computations in batches and leveraging GPU acceleration. + + Parameters + ---------- + filtered_hessian : List[torch.Tensor] + A list of six torch.Tensors representing the Hessian matrix components: + [hessianXX, hessianYY, hessianZZ, hessianXY, hessianXZ, hessianYZ] + batch_size : int + The number of matrices to include in each batch for processing. + labels : np.ndarray + A 3D numpy array representing the binary mask where 1 indicates a voxel + to be processed. + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + A tuple containing two numpy arrays: + - first_eigenvalues: + A 3D array with the largest eigenvalues for the processed voxels. + - first_eigenvectors: + A 4D array with the corresponding eigenvectors for these eigenvalues. 
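A minimal call sketch, with random values and purely illustrative shapes, showing how the function documented above is intended to be invoked; the component order follows the docstring (XX, YY, ZZ, XY, XZ, YZ) and only voxels where `labels` equals 1 are decomposed.

```python
import numpy as np
import torch

from membrain_seg.segmentation.skeletonization.eig3d import (
    batch_mask_eigendecomposition_3d,
)

shape = (16, 16, 16)
# Six Hessian components in the documented order: XX, YY, ZZ, XY, XZ, YZ.
hessian_components = [torch.rand(shape) for _ in range(6)]
# Binary mask: only voxels labelled 1 are processed.
labels = (np.random.rand(*shape) > 0.5).astype(np.float32)

eigenvalues, eigenvectors = batch_mask_eigendecomposition_3d(
    hessian_components, batch_size=1024, labels=labels
)
print(eigenvalues.shape, eigenvectors.shape)  # (16, 16, 16) (16, 16, 16, 3)
```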
+ """ + hessianXX, hessianYY, hessianZZ, hessianXY, hessianXZ, hessianYZ = filtered_hessian + del filtered_hessian + Nx, Ny, Nz = hessianXX.shape + + # Set the batch size to the total number of voxels + # if no specified batch size is given + if batch_size is None: + batch_size = Nx * Ny * Nz + print("batch_size=", batch_size) + + # Identify coordinates where computation is needed + active_voxel_coords = np.where(labels == 1) + x_indices, y_indices, z_indices = active_voxel_coords + num_active_voxels = x_indices.shape[0] + + # Prepare a tensor stack for the selected Hessian matrix components + hessian_components = torch.stack( + [ + hessianXX[x_indices, y_indices, z_indices], + hessianXY[x_indices, y_indices, z_indices], + hessianXZ[x_indices, y_indices, z_indices], + hessianXY[x_indices, y_indices, z_indices], + hessianYY[x_indices, y_indices, z_indices], + hessianYZ[x_indices, y_indices, z_indices], + hessianXZ[x_indices, y_indices, z_indices], + hessianYZ[x_indices, y_indices, z_indices], + hessianZZ[x_indices, y_indices, z_indices], + ], + dim=-1, + ).view(-1, 3, 3) + del hessianXX, hessianYY, hessianZZ, hessianXY, hessianXZ, hessianYZ + print("Hessian component matrix shape:", hessian_components.shape) + + # Initialize output arrays + first_eigenvalues = np.zeros((Nx, Ny, Nz), dtype=np.float32) + first_eigenvectors = np.zeros((Nx, Ny, Nz, 3), dtype=np.float32) + + # Process in batches + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + for i in range(0, num_active_voxels, batch_size): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + # print('i=', i) + # print(f"Allocated: {torch.cuda.memory_allocated(0)/1e9:.2f} GB") + # print(f"Cached: {torch.cuda.memory_reserved(0)/1e9:.2f} GB") + + i_end = min(i + batch_size, num_active_voxels) + batch_matrix = hessian_components[i:i_end, :, :] + + # Compute eigenvalues and eigenvectors for this batch + eigenvalues, eigenvectors = torch.linalg.eig(batch_matrix.to(device)) + max_eigenvalue_idx = torch.argmax(torch.abs(eigenvalues), dim=-1) + batch_first_eigenvalues = eigenvalues[ + torch.arange(len(max_eigenvalue_idx)), max_eigenvalue_idx + ] + batch_first_eigenvectors = eigenvectors[ + torch.arange(len(max_eigenvalue_idx)), :, max_eigenvalue_idx + ] + + # Store results back to CPU to save cuda memory + first_eigenvalues[ + x_indices[i:i_end], y_indices[i:i_end], z_indices[i:i_end] + ] = batch_first_eigenvalues.cpu().numpy().real + first_eigenvectors[ + x_indices[i:i_end], y_indices[i:i_end], z_indices[i:i_end], : + ] = (batch_first_eigenvectors.view(-1, 3).cpu().numpy()).real + + return first_eigenvalues, first_eigenvectors diff --git a/src/membrain_seg/segmentation/skeletonization/nonmaxsup.py b/src/membrain_seg/segmentation/skeletonization/nonmaxsup.py new file mode 100644 index 0000000..678dc52 --- /dev/null +++ b/src/membrain_seg/segmentation/skeletonization/nonmaxsup.py @@ -0,0 +1,178 @@ +# --------------------------------------------------------------------------------- +# DISCLAIMER: This code is adapted from the MATLAB and C++ implementations provided +# in the paper titled "Robust membrane detection based on tensor voting for electron +# tomography" by Antonio Martinez-Sanchez, Inmaculada Garcia, Shoh Asano, Vladan Lucic, +# and Jose-Jesus Fernandez, published in the Journal of Structural Biology, +# Volume 186, Issue 1, 2014, Pages 49-61. The original work can be accessed via +# https://doi.org/10.1016/j.jsb.2014.02.015 and is used under conditions that adhere +# to the original licensing agreements. 
For details on the original license, refer to +# the publication: https://www.sciencedirect.com/science/article/pii/S1047847714000495. +# --------------------------------------------------------------------------------- +from typing import List, Tuple + +import numpy as np + + +def nonmaxsup_kernel( + image: np.ndarray, + vector_x: np.ndarray, + vector_y: np.ndarray, + vector_z: np.ndarray, + mask_coords: List[Tuple[int, int, int]], + interpolation_factor: float, +) -> np.ndarray: + """ + Apply non-maximum suppression based on trilinear interpolation to + enhance ridge structures in a 3D image. + + This function adjusts the influence of eigenvectors at each voxel + using a given interpolation factor, and marks local maxima + in the output array indicating significant ridge features. + + Parameters + ---------- + image : np.ndarray + The 3D image data array from which ridges are to be enhanced. + Dimension: (Nx, Ny, Nz). + vector_x : np.ndarray + The x-component of the eigenvector associated with the largest + eigenvalue at each voxel. + vector_y : np.ndarray + The y-component of the eigenvector. + vector_z : np.ndarray + The z-component of the eigenvector. + mask_coords : List[Tuple[int, int, int]] + A list of coordinates where the non-maximum suppression is to + be applied. + interpolation_factor : float + A factor used in the trilinear interpolation for calculating + the adjacent values. + + Returns + ------- + np.ndarray + A 3D binary array the same size as `image`, containing 1s at voxels + identified as local maxima and 0s elsewhere. + """ + # Initialize the output suppression matrix + result = np.zeros_like(image) + + # Convert mask_coords to a structured NumPy array for efficient indexing + coords = np.array(mask_coords, dtype=[("x", int), ("y", int), ("z", int)]) + x, y, z = coords["x"], coords["y"], coords["z"] + + # Compute normalized interpolation coefficients based on + # the eigenvector components and interpolation factor + dx = np.abs(vector_x[x, y, z] * interpolation_factor) + dy = np.abs(vector_y[x, y, z] * interpolation_factor) + dz = np.abs(vector_z[x, y, z] * interpolation_factor) + + # Calculate indices for forward and backward interpolation + # based on the directionality of the eigenvector components + next_x = x + np.sign(dx).astype(int) + next_y = y + np.sign(dy).astype(int) + next_z = z + np.sign(dz).astype(int) + prev_x = x - np.sign(dx).astype(int) + prev_y = y - np.sign(dy).astype(int) + prev_z = z - np.sign(dz).astype(int) + + # Calculate trilinear interpolated values + # for forward (+ve) and backward (-ve) directions + interpolated_values_forward = ( + image[x, y, z] * (1 - dx) * (1 - dy) * (1 - dz) + + image[next_x, y, z] * dx * (1 - dy) * (1 - dz) + + image[x, next_y, z] * (1 - dx) * dy * (1 - dz) + + image[x, y, next_z] * (1 - dx) * (1 - dy) * dz + + image[next_x, next_y, z] * dx * dy * (1 - dz) + + image[next_x, y, next_z] * dx * (1 - dy) * dz + + image[x, next_y, next_z] * (1 - dx) * dy * dz + + image[next_x, next_y, next_z] * dx * dy * dz + ) + + interpolated_values_backward = ( + image[x, y, z] * (1 - dx) * (1 - dy) * (1 - dz) + + image[prev_x, y, z] * dx * (1 - dy) * (1 - dz) + + image[x, prev_y, z] * (1 - dx) * dy * (1 - dz) + + image[x, y, prev_z] * (1 - dx) * (1 - dy) * dz + + image[prev_x, prev_y, z] * dx * dy * (1 - dz) + + image[prev_x, y, prev_z] * dx * (1 - dy) * dz + + image[x, prev_y, prev_z] * (1 - dx) * dy * dz + + image[prev_x, prev_y, prev_z] * dx * dy * dz + ) + + # Local values from image at specified coordinates + local_values = 
image[x, y, z] + + # Determine local maxima by comparing local values to interpolated values + result[x, y, z] = (local_values > interpolated_values_forward) & ( + local_values > interpolated_values_backward + ) + return result + + +def nonmaxsup( + eigenvalues: np.ndarray, + eigenvector_x: np.ndarray, + eigenvector_y: np.ndarray, + eigenvector_z: np.ndarray, + labels: np.ndarray, +) -> np.ndarray: + """ + Perform non-maximum suppression on the given tomogram to detect ridge centrelines. + + This function applies a non-maximum suppression algorithm to identify and enhance + ridge-like structures in volumetric data based on eigenvalues and the major + eigenvector's components. The process involves masking with the input labels to + focus only on regions of interest and suppressing non-ridge areas. + + Parameters + ---------- + eigenvalues : np.ndarray + The eigenvalues of the Hessian matrix, used to identify potential ridge points. + eigenvector_x : np.ndarray + X-component of the principal eigenvector associated with the largest eigenvalue. + eigenvector_y : np.ndarray + Y-component of the principal eigenvector. + eigenvector_z : np.ndarray + Z-component of the principal eigenvector. + labels : np.ndarray + A binary mask where 1 indicates regions of interest and 0 indicates background. + + Returns + ------- + np.ndarray + A binary array where 1 indicates detected ridge centreline + and 0 indicates background. + + Notes + ----- + The non-maximum suppression is focused within the regions specified by the labels. + The algorithm leverages an interpolation factor to adjust the suppression + sensitivity and is implemented through a specific kernel function + for efficient processing. + """ + # Define the boundary for processing to avoid edge effects + Nx, Ny, Nz = eigenvalues.shape + margin = 1 + mask = np.zeros((Nx, Ny, Nz)) + mask[margin : Nx - margin, margin : Ny - margin, margin : Nz - margin] = 1 + masked_labels = labels * mask + + # Filter coordinates where suppression is applicable + relevant_coords = np.where(masked_labels == 1) + coordinates_list = list( + zip(relevant_coords[0], relevant_coords[1], relevant_coords[2]) + ) + + # Define interpolation factor for kernel processing + interpolation_factor = 0.71 + binary_output = nonmaxsup_kernel( + eigenvalues, + eigenvector_x, + eigenvector_y, + eigenvector_z, + coordinates_list, + interpolation_factor, + ) + + return binary_output diff --git a/src/membrain_seg/segmentation/skeletonize.py b/src/membrain_seg/segmentation/skeletonize.py new file mode 100644 index 0000000..82b9a1e --- /dev/null +++ b/src/membrain_seg/segmentation/skeletonize.py @@ -0,0 +1,121 @@ +# --------------------------------------------------------------------------------- +# DISCLAIMER: This code is adapted from the MATLAB and C++ implementations provided +# in the paper titled "Robust membrane detection based on tensor voting for electron +# tomography" by Antonio Martinez-Sanchez, Inmaculada Garcia, Shoh Asano, Vladan Lucic, +# and Jose-Jesus Fernandez, published in the Journal of Structural Biology, +# Volume 186, Issue 1, 2014, Pages 49-61. The original work can be accessed via +# https://doi.org/10.1016/j.jsb.2014.02.015 and is used under conditions that adhere +# to the original licensing agreements. For details on the original license, refer to +# the publication: https://www.sciencedirect.com/science/article/pii/S1047847714000495. 
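As a one-dimensional intuition for the suppression rule implemented in `nonmaxsup_kernel` above (illustration only, not package code): a voxel survives only if its ridge response exceeds the values interpolated a sub-voxel step away on both sides along the dominant eigenvector; the 0.71 step mirrors the interpolation factor used in `nonmaxsup`.

```python
import numpy as np

# Ridge response sampled along the dominant eigenvector (the membrane normal).
profile = np.array([0.1, 0.4, 1.0, 0.5, 0.2])
step = 0.71  # same interpolation factor as used in nonmaxsup()
centre = 2

forward = np.interp(centre + step, np.arange(profile.size), profile)
backward = np.interp(centre - step, np.arange(profile.size), profile)

# Keep the point only if it beats both interpolated neighbours: a centreline point.
is_centreline = profile[centre] > forward and profile[centre] > backward
print(is_centreline)  # True
```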
+# ---------------------------------------------------------------------------------
+import numpy as np
+import scipy.ndimage as ndimage
+import torch
+
+from membrain_seg.segmentation.dataloading.data_utils import load_tomogram
+from membrain_seg.segmentation.skeletonization.diff3d import (
+    compute_gradients,
+    compute_hessian,
+)
+from membrain_seg.segmentation.skeletonization.eig3d import (
+    batch_mask_eigendecomposition_3d,
+)
+from membrain_seg.segmentation.skeletonization.nonmaxsup import nonmaxsup
+from membrain_seg.segmentation.training.surface_dice import apply_gaussian_filter
+
+
+def skeletonization(label_path: str, batch_size: int) -> np.ndarray:
+    """
+    Perform skeletonization on a tomogram segmentation.
+
+    This function reads a segmentation file (label_path). It performs skeletonization on
+    the segmentation, where the non-zero labels represent the structures of interest.
+    The resultant skeleton is saved with '_skel' appended to the filename.
+
+    Parameters
+    ----------
+    label_path : str
+        Path to the input tomogram segmentation file.
+    batch_size : int
+        The number of elements to process in one batch during eigen decomposition.
+        Useful for managing memory usage.
+
+    Returns
+    -------
+    ndarray
+        Returns the skeletonized image as a numpy array.
+
+    Notes
+    -----
+    The skeletonization is based on the computation of the distance transform
+    of the non-zero regions (foreground), followed by an eigenvalue analysis
+    of the Hessian matrix of the distance transform to identify ridge-like
+    structures corresponding to the centerlines of the segmented objects.
+
+    Examples
+    --------
+    >>> membrain skeletonize --label-path <label-path> --out-folder <out-folder>
+        --batch-size 1000000
+    This command runs the skeletonization process from the command line.
+    """
+    # Read original segmentation
+    segmentation = load_tomogram(label_path)
+    segmentation = segmentation.data
+
+    # Convert non-zero segmentation values to 1.0
+    labels = (segmentation > 0) * 1.0
+
+    print("Distance transform on original labels.")
+    labels_dt = ndimage.distance_transform_edt(labels) * (-1)
+
+    # Calculates partial derivative along 3 dimensions.
+    print("Computing partial derivative.")
+    gradX, gradY, gradZ = compute_gradients(labels_dt)
+
+    # Calculates Hessian tensor
+    print("Computing Hessian tensor.")
+    hessianXX, hessianYY, hessianZZ, hessianXY, hessianXZ, hessianYZ = compute_hessian(
+        gradX, gradY, gradZ
+    )
+    hessians = [hessianXX, hessianYY, hessianZZ, hessianXY, hessianXZ, hessianYZ]
+    del gradX, gradY, gradZ
+
+    # Apply Gaussian filter with the same sigma value for all dimensions
+    print("Applying Gaussian filtering.")
+    # Load hessian tensors on GPU
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Using device: {device}")
+
+    filtered_hessian = [
+        apply_gaussian_filter(
+            torch.from_numpy(comp).float().to(device).unsqueeze(0).unsqueeze(0),
+            kernel_size=9,
+            sigma=1.0,
+        )
+        .squeeze()
+        .to("cpu")
+        for comp in hessians
+    ]
+
+    # Solve Eigen problem
+    print("Computing Eigenvalues and Eigenvectors.")
+    print(
+        "In case the execution of the program is terminated unexpectedly, "
+        "attempt to rerun it using smaller segmentation patches "
+        "or specify a batch size as input, e.g. batch_size=1000000."
+    )
+    first_eigenvalue, first_eigenvector = batch_mask_eigendecomposition_3d(
+        filtered_hessian, batch_size, labels
+    )
+
+    # Non-maximum suppression
+    print("Generation of skeleton based on non-maximum suppression algorithm.")
+    first_eigenvalue = ndimage.gaussian_filter(first_eigenvalue, sigma=1)
+    skeleton = nonmaxsup(
+        first_eigenvalue,
+        first_eigenvector[:, :, :, 0],
+        first_eigenvector[:, :, :, 1],
+        first_eigenvector[:, :, :, 2],
+        labels,
+    )
+    return skeleton
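For completeness, a sketch of how the new module can be driven directly from Python, mirroring what the `membrain skeletonize` CLI command added in this PR does (paths are placeholders):

```python
import os

from membrain_seg.segmentation.dataloading.data_utils import store_tomogram
from membrain_seg.segmentation.skeletonize import skeletonization

label_path = "path/to/membrane_segmentation.mrc"  # placeholder input segmentation
out_folder = "./predictions"

# batch_size=None processes the whole volume at once; roughly 1,000,000 is the
# suggested value when GPU memory is limited (see the CLI help text above).
skeleton = skeletonization(label_path=label_path, batch_size=1_000_000)

os.makedirs(out_folder, exist_ok=True)
out_file = os.path.join(
    out_folder, os.path.splitext(os.path.basename(label_path))[0] + "_skel.mrc"
)
store_tomogram(filename=out_file, tomogram=skeleton)
```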