On the fly rescaling (GPU) #64

Merged

merged 11 commits on May 23, 2024
1 change: 1 addition & 0 deletions src/membrain_seg/segmentation/cli/cli.py
@@ -18,6 +18,7 @@ def list_commands(self, ctx: Context):
add_completion=False,
no_args_is_help=True,
rich_markup_mode="rich",
pretty_exceptions_show_locals=False
Comment from @LorenzLamm (Collaborator, Author) on Mar 25, 2024:
I don't know how you feel, but I don't like the default printing of local variables for debugging most of the time. It's a bit annoying because it prints the entire model weights.

Of course, it also provides more detailed information, but maybe this could be an advanced option?
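
A minimal sketch of that "advanced option" idea (the environment variable name MEMBRAIN_DEBUG is hypothetical, not part of this PR): keep locals hidden by default and let a debug switch re-enable them.

import os

import typer

# Hypothetical opt-in: locals are printed in tracebacks only when
# MEMBRAIN_DEBUG is set to a non-empty value.
cli = typer.Typer(
    pretty_exceptions_show_locals=bool(os.environ.get("MEMBRAIN_DEBUG")),
)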

)
OPTION_PROMPT_KWARGS = {"prompt": True, "prompt_required": True}
PKWARGS = OPTION_PROMPT_KWARGS
16 changes: 16 additions & 0 deletions src/membrain_seg/segmentation/cli/segment_cli.py
@@ -27,6 +27,19 @@ def segment(
out_folder: str = Option( # noqa: B008
"./predictions", help="Path to the folder where segmentations should be stored."
),
rescale_patches: bool = Option( # noqa: B008
False, help="Should patches be rescaled on-the-fly during inference?"
),
in_pixel_size: float = Option( # noqa: B008
None,
help="Pixel size of the input tomogram in Angstrom. \
(default: 10 Angstrom)",
),
out_pixel_size: float = Option( # noqa: B008
10.,
help="Pixel size of the output segmentation in Angstrom. \
(default: 10 Angstrom; should normally stay at 10 Angstrom)",
),
store_probabilities: bool = Option( # noqa: B008
False, help="Should probability maps be output in addition to segmentations?"
),
@@ -66,6 +79,9 @@ def segment(
tomogram_path=tomogram_path,
ckpt_path=ckpt_path,
out_folder=out_folder,
rescale_patches=rescale_patches,
in_pixel_size=in_pixel_size,
out_pixel_size=out_pixel_size,
store_probabilities=store_probabilities,
store_connected_components=store_connected_components,
connected_component_thres=connected_component_thres,
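For illustration, here is how the new options map onto a direct call to the underlying segment() function (file paths and pixel sizes are hypothetical):

from membrain_seg.segmentation.segment import segment

# Hypothetical call mirroring the new CLI flags.
segment(
    tomogram_path="tomo.mrc",       # illustrative path
    ckpt_path="checkpoint.ckpt",    # illustrative path
    out_folder="./predictions",
    rescale_patches=True,
    in_pixel_size=14.08,            # e.g. pixel size of the raw tomogram
    out_pixel_size=10.0,            # the network's training pixel size
)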
105 changes: 105 additions & 0 deletions src/membrain_seg/segmentation/networks/inference_unet.py
@@ -0,0 +1,105 @@
from typing import Tuple

import torch
import torch.nn.functional as F

from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet
from membrain_seg.tomo_preprocessing.matching_utils.px_matching_utils import (
fourier_cropping_torch,
fourier_extend_torch,
)


Comment from the Collaborator Author: This is for re-scaling the prediction scores.

def rescale_tensor(
sample: torch.Tensor, target_size: tuple, mode="trilinear"
) -> torch.Tensor:
"""
Rescales the input tensor by given factors using interpolation.

Parameters
----------
sample : torch.Tensor
The input data as a torch tensor.
target_size : tuple
The target size of the rescaled tensor.
mode : str, optional
The mode of interpolation ('nearest', 'linear', 'bilinear',
'bicubic', or 'trilinear'). Default is 'trilinear'.

Returns
-------
torch.Tensor
The rescaled tensor.
"""
# Add batch and channel dimensions
sample = sample.unsqueeze(0).unsqueeze(0)

# Apply interpolation
rescaled_sample = F.interpolate(
sample, size=target_size, mode=mode, align_corners=False
)

return rescaled_sample.squeeze(0).squeeze(0)
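
For example, a minimal usage sketch of rescale_tensor (shapes are illustrative):

import torch

patch = torch.rand(128, 128, 128)
upscaled = rescale_tensor(patch, target_size=(160, 160, 160))
print(upscaled.shape)  # torch.Size([160, 160, 160])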


class PreprocessedSemanticSegmentationUnet(SemanticSegmentationUnet):
"""U-Net with rescaling preprocessing.

This class extends the SemanticSegmentationUnet class by adding
preprocessing and postprocessing steps. The preprocessing step
rescales the input to the target shape, and the postprocessing
step rescales the output to the original shape.
All of this is done on the GPU if available.
"""

def __init__(
self,
*args,
rescale_patches: bool = False, # Should patches be rescaled?
target_shape: Tuple[int, int, int] = (160, 160, 160),
**kwargs,
):
super().__init__(*args, **kwargs)
# Store the preprocessing parameters
self.rescale_patches = rescale_patches
self.target_shape = target_shape

def preprocess(self, x):
"""Preprocess the input to the network.

In this case, we rescale the input to the target shape.
"""
rescaled_samples = []
for sample in x:
sample = sample[0] # only use the first channel
if self.rescale_patches:
if sample.shape[0] > self.target_shape[0]:
sample = fourier_cropping_torch(sample, self.target_shape)
elif sample.shape[0] < self.target_shape[0]:
sample = fourier_extend_torch(sample, self.target_shape)
rescaled_samples.append(sample.unsqueeze(0))
rescaled_samples = torch.stack(rescaled_samples, dim=0)
return rescaled_samples

def postprocess(self, x, orig_shape):
"""Postprocess the output of the network.

In this case, we rescale the output to the original shape.
"""
rescaled_samples = []
for sample in x:
sample = sample[0] # only use first channel
if self.rescale_patches:
sample = rescale_tensor(sample, orig_shape, mode="trilinear")
rescaled_samples.append(sample.unsqueeze(0))
rescaled_samples = torch.stack(rescaled_samples, dim=0)
return rescaled_samples

def forward(self, x):
"""Forward pass through the network."""
orig_shape = x.shape[2:]
preprocessed_x = self.preprocess(x)
predicted = super().forward(preprocessed_x)
postprocessed_predicted = self.postprocess(predicted[0], orig_shape)
# Return list to be compatible with deep supervision outputs
return [postprocessed_predicted]
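
A hedged smoke test for the wrapper (assuming SemanticSegmentationUnet can be instantiated with its default arguments, which this PR does not show; weights are random, so only shapes are checked):

import torch

# Hypothetical check: a 128^3 patch is rescaled to the 160^3 network
# input, and the prediction is rescaled back to 128^3.
model = PreprocessedSemanticSegmentationUnet(
    rescale_patches=True, target_shape=(160, 160, 160)
)
model.eval()
x = torch.rand(1, 1, 128, 128, 128)  # batch x channel x depth x height x width
with torch.no_grad():
    out = model(x)[0]
assert out.shape[2:] == x.shape[2:]  # original spatial shape restored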
69 changes: 54 additions & 15 deletions src/membrain_seg/segmentation/segment.py
@@ -3,7 +3,12 @@
import torch
from monai.inferers import SlidingWindowInferer

from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet
from membrain_seg.segmentation.networks.inference_unet import (
PreprocessedSemanticSegmentationUnet,
)
from membrain_seg.tomo_preprocessing.matching_utils.px_matching_utils import (
determine_output_shape,
)

from .dataloading.data_utils import (
load_data_for_inference,
@@ -16,6 +21,9 @@ def segment(
tomogram_path,
ckpt_path,
out_folder,
rescale_patches=False,
in_pixel_size=None,
out_pixel_size=10.0,
store_probabilities=False,
sw_roi_size=160,
store_connected_components=False,
@@ -40,6 +48,12 @@
Path to the trained model checkpoint file.
out_folder : str
Path to the folder where the output segmentations should be stored.
rescale_patches : bool, optional
If True, rescale the patches to the output pixel size (default is False).
in_pixel_size : float, optional
Pixel size of the input tomogram in Angstrom. If None (default), it is read from the tomogram header.
out_pixel_size : float, optional
Pixel size of the output segmentation in Angstrom (default is 10.0).
store_probabilities : bool, optional
If True, store the predicted probabilities along with the segmentations
(default is False).
@@ -78,10 +92,13 @@
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the model and load trained weights from checkpoint
pl_model = SemanticSegmentationUnet.load_from_checkpoint(
pl_model = PreprocessedSemanticSegmentationUnet.load_from_checkpoint(
model_checkpoint, map_location=device, strict=False
)
pl_model.to(device)
if sw_roi_size % 32 != 0:
raise OSError("Sliding window size must be multiple of 32!")
pl_model.target_shape = (sw_roi_size, sw_roi_size, sw_roi_size)

# Preprocess the new data
new_data_path = tomogram_path
@@ -91,12 +108,34 @@
)
new_data = new_data.to(torch.float32)

if rescale_patches:
# Rescale patches if necessary
if in_pixel_size is None:
in_pixel_size = voxel_size.x
if in_pixel_size == 0.0:
raise ValueError(
"Input pixel size is 0.0. Please specify the pixel size manually."
)
if in_pixel_size == 1.0:
print(
"WARNING: Input pixel size is 1.0. Looks like a corrupt header.",
"Please specify the pixel size manually.",
)
pl_model.rescale_patches = in_pixel_size != out_pixel_size

# Determine the sliding window size according to the input and output pixel size
sw_roi_size = determine_output_shape(
# switch in and out pixel size to get SW shape
pixel_size_in=out_pixel_size,
pixel_size_out=in_pixel_size,
orig_shape=(sw_roi_size, sw_roi_size, sw_roi_size),
)
sw_roi_size = sw_roi_size[0]
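
To make the in/out switch concrete, a sketch of the window-size arithmetic (assuming determine_output_shape scales the shape by pixel_size_in / pixel_size_out and rounds; the numbers are hypothetical):

in_pixel_size, out_pixel_size, sw_roi_size = 5.0, 10.0, 160
raw_window = round(sw_roi_size * out_pixel_size / in_pixel_size)  # 320
# A 320^3 window sampled at 5 Angstrom is Fourier-cropped to the 160^3
# network input at 10 Angstrom, and the prediction is rescaled back to
# 320^3 before stitching.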

# Put the model into evaluation mode
pl_model.eval()

# Perform sliding window inference on the new data
if sw_roi_size % 32 != 0:
raise OSError("Sliding window size must be multiple of 32!")
roi_size = (sw_roi_size, sw_roi_size, sw_roi_size)
sw_batch_size = 1
inferer = SlidingWindowInferer(
@@ -110,20 +149,20 @@

# Perform test time augmentation (8-fold mirroring)
predictions = torch.zeros_like(new_data)
print("Performing 8-fold test-time augmentation.")
if test_time_augmentation:
print(
"Performing 8-fold test-time augmentation.",
"I.e. the following bar will run 8 times.",
)
for m in range(8 if test_time_augmentation else 1):
with torch.no_grad():
with torch.cuda.amp.autocast():
predictions += (
get_mirrored_img(
inferer(
get_mirrored_img(new_data.clone(), m).to(device), pl_model
)[0],
m,
)
.detach()
.cpu()
)
mirrored_input = get_mirrored_img(new_data.clone(), m).to(device)
mirrored_pred = inferer(mirrored_input, pl_model)
if not isinstance(mirrored_pred, list):
mirrored_pred = [mirrored_pred]
correct_pred = get_mirrored_img(mirrored_pred[0], m)
predictions += correct_pred.detach().cpu()
if test_time_augmentation:
predictions /= 8.0
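
The refactored loop implements standard 8-fold mirroring TTA. A generic illustration (not the project's get_mirrored_img, whose implementation this diff does not show) averages predictions over every combination of flips along the three spatial axes:

import itertools

import torch

# Generic 8-fold mirroring TTA sketch: the 2^3 flip combinations.
def tta_predict(model, x):
    acc = torch.zeros_like(x)
    for r in range(4):
        for axes in itertools.combinations((2, 3, 4), r):
            flipped = torch.flip(x, axes) if axes else x
            pred = model(flipped)
            # Flip the prediction back before accumulating.
            acc += torch.flip(pred, axes) if axes else pred
    return acc / 8.0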

src/membrain_seg/tomo_preprocessing/matching_utils/px_matching_utils.py
@@ -1,10 +1,95 @@
from typing import Tuple, Union

import numpy as np
import torch
import torch.fft
from scipy.fft import fftn, ifftn
from scipy.ndimage import distance_transform_edt


def fourier_cropping_torch(data: torch.Tensor, new_shape: tuple) -> torch.Tensor:
"""
Fourier cropping adapted for PyTorch and GPU, without smoothing functionality.

Parameters
----------
data : torch.Tensor
The input data as a 3D torch tensor on GPU.
new_shape : tuple
The target shape for the cropped data as a tuple (x, y, z).

Returns
-------
torch.Tensor
The resized data as a 3D torch tensor.
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
data = data.to(device)

# Calculate the FFT of the input data
data_fft = torch.fft.fftn(data)
data_fft = torch.fft.fftshift(data_fft)

# Calculate the cropping indices
original_shape = torch.tensor(data.shape, device=device)
new_shape = torch.tensor(new_shape, device=device)
start_indices = (original_shape - new_shape) // 2
end_indices = start_indices + new_shape

# Crop the filtered FFT data
cropped_fft = data_fft[
start_indices[0] : end_indices[0],
start_indices[1] : end_indices[1],
start_indices[2] : end_indices[2],
]

unshifted_cropped_fft = torch.fft.ifftshift(cropped_fft)

# Calculate the inverse FFT of the cropped data
resized_data = torch.real(torch.fft.ifftn(unshifted_cropped_fft))

return resized_data


def fourier_extend_torch(data: torch.Tensor, new_shape: tuple) -> torch.Tensor:
"""
Fourier padding adapted for PyTorch and GPU, without smoothing functionality.

Parameters
----------
data : torch.Tensor
The input data as a 3D torch tensor on GPU.
new_shape : tuple
The target shape for the extended data as a tuple (x, y, z).

Returns
-------
torch.Tensor
The resized data as a 3D torch tensor.
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
data = data.to(device)

data_fft = torch.fft.fftn(data)
data_fft = torch.fft.fftshift(data_fft)

padding = [
(new_dim - old_dim) // 2 for old_dim, new_dim in zip(data.shape, new_shape)
]
padded_fft = torch.nn.functional.pad(
data_fft,
pad=[pad for pair in zip(padding, padding) for pad in pair],
mode="constant",
)

unshifted_padded_fft = torch.fft.ifftshift(padded_fft)

# Calculate the inverse FFT of the padded data
resized_data = torch.real(torch.fft.ifftn(unshifted_padded_fft))

return resized_data
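
A quick shape roundtrip for the two resizers above (sizes are illustrative):

import torch

vol = torch.rand(160, 160, 160)
small = fourier_cropping_torch(vol, (80, 80, 80))   # keep only low frequencies
big = fourier_extend_torch(small, (160, 160, 160))  # zero-pad the spectrum
print(small.shape, big.shape)
# torch.Size([80, 80, 80]) torch.Size([160, 160, 160])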


def smooth_cosine_dropoff(mask: np.ndarray, decay_width: float) -> np.ndarray:
"""
Apply a smooth cosine drop-off to a given mask.