On the fly rescaling (GPU) #64
Merged
Commits (11):
- f294ab3 Remove local variables display from pretty print (LorenzLamm)
- 64d9309 On the fly rescaling (LorenzLamm)
- 8fa5d53 Read pixel size from header by default (LorenzLamm)
- 9fa3df6 Fix voxel size from header loading (LorenzLamm)
- 3819621 add rescaling option to CLI (LorenzLamm)
- 6bbdf21 remove print statements (LorenzLamm)
- d1b21c0 Move rescaling to torch GPU (LorenzLamm)
- e1db319 remove hard-coded GPU requirement (LorenzLamm)
- 6683a3a Fix test time augmentation with rescaling SWInferer (LorenzLamm)
- e9d4072 Merge branch 'teamtomo:main' into on-the-fly-rescaling (LorenzLamm)
- ef62e08 read device from model by default (LorenzLamm)
src/membrain_seg/segmentation/networks/inference_unet.py (105 additions, 0 deletions)
Reviewer note on `rescale_tensor` (LorenzLamm): this is for re-scaling the prediction scores.

```python
from typing import Tuple

import torch
import torch.nn.functional as F

from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet
from membrain_seg.tomo_preprocessing.matching_utils.px_matching_utils import (
    fourier_cropping_torch,
    fourier_extend_torch,
)


def rescale_tensor(
    sample: torch.Tensor, target_size: tuple, mode: str = "trilinear"
) -> torch.Tensor:
    """
    Rescale the input tensor to a target size using interpolation.

    Parameters
    ----------
    sample : torch.Tensor
        The input data as a torch tensor.
    target_size : tuple
        The target size of the rescaled tensor.
    mode : str, optional
        The mode of interpolation ('nearest', 'linear', 'bilinear',
        'bicubic', or 'trilinear'). Default is 'trilinear'.

    Returns
    -------
    torch.Tensor
        The rescaled tensor.
    """
    # Add batch and channel dimensions
    sample = sample.unsqueeze(0).unsqueeze(0)

    # Apply interpolation
    rescaled_sample = F.interpolate(
        sample, size=target_size, mode=mode, align_corners=False
    )

    return rescaled_sample.squeeze(0).squeeze(0)


class PreprocessedSemanticSegmentationUnet(SemanticSegmentationUnet):
    """U-Net with rescaling preprocessing.

    This class extends the SemanticSegmentationUnet class by adding
    preprocessing and postprocessing steps. The preprocessing step
    rescales the input to the target shape, and the postprocessing
    step rescales the output to the original shape.
    All of this is done on the GPU if available.
    """

    def __init__(
        self,
        *args,
        rescale_patches: bool = False,  # Should patches be rescaled?
        target_shape: Tuple[int, int, int] = (160, 160, 160),
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Store the preprocessing parameters
        self.rescale_patches = rescale_patches
        self.target_shape = target_shape

    def preprocess(self, x):
        """Preprocess the input to the network.

        In this case, we rescale the input to the target shape.
        """
        rescaled_samples = []
        for sample in x:
            sample = sample[0]  # only use the first channel
            if self.rescale_patches:
                if sample.shape[0] > self.target_shape[0]:
                    sample = fourier_cropping_torch(sample, self.target_shape)
                elif sample.shape[0] < self.target_shape[0]:
                    sample = fourier_extend_torch(sample, self.target_shape)
            rescaled_samples.append(sample.unsqueeze(0))
        rescaled_samples = torch.stack(rescaled_samples, dim=0)
        return rescaled_samples

    def postprocess(self, x, orig_shape):
        """Postprocess the output of the network.

        In this case, we rescale the output to the original shape.
        """
        rescaled_samples = []
        for sample in x:
            sample = sample[0]  # only use first channel
            if self.rescale_patches:
                sample = rescale_tensor(sample, orig_shape, mode="trilinear")
            rescaled_samples.append(sample.unsqueeze(0))
        rescaled_samples = torch.stack(rescaled_samples, dim=0)
        return rescaled_samples

    def forward(self, x):
        """Forward pass through the network."""
        orig_shape = x.shape[2:]
        preprocessed_x = self.preprocess(x)
        predicted = super().forward(preprocessed_x)
        postprocessed_predicted = self.postprocess(predicted[0], orig_shape)
        # Return list to be compatible with deep supervision outputs
        return [postprocessed_predicted]
```
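As a quick sanity check on the postprocessing path: `rescale_tensor` is a thin wrapper around `F.interpolate`, which in `'trilinear'` mode expects a 5D `(N, C, D, H, W)` input; that is why the function adds and then strips the batch and channel dimensions. A standalone sketch (the 80³ patch size here is an arbitrary illustration, not taken from the PR):

```python
import torch
import torch.nn.functional as F

# Dummy 80^3 patch; trilinear interpolation needs (N, C, D, H, W),
# so we add batch/channel dims before and remove them after.
patch = torch.rand(80, 80, 80)
resized = F.interpolate(
    patch.unsqueeze(0).unsqueeze(0),
    size=(160, 160, 160),
    mode="trilinear",
    align_corners=False,
).squeeze(0).squeeze(0)
print(tuple(resized.shape))  # (160, 160, 160)
```

The same call covers both up- and downsampling, which is why a single function suffices for mapping prediction scores back to the original patch shape.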
src/membrain_seg/tomo_preprocessing/matching_utils/px_matching_utils.py (85 additions, 0 deletions)
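The diff for px_matching_utils.py is not rendered above; it supplies the `fourier_cropping_torch` and `fourier_extend_torch` helpers used in `preprocess`. For intuition, Fourier cropping downsamples a volume by keeping only the central (low-frequency) block of its centered spectrum. A minimal sketch of the idea (an assumption about the general technique, not the PR's actual implementation):

```python
import torch

def fourier_crop(volume: torch.Tensor, target_shape: tuple) -> torch.Tensor:
    """Downsample a 3D volume by cropping its centered Fourier spectrum.

    Hypothetical sketch; assumes target_shape <= volume.shape in every dim.
    """
    # Center the zero-frequency component
    spectrum = torch.fft.fftshift(torch.fft.fftn(volume))
    # Keep the central low-frequency block
    starts = [(s - t) // 2 for s, t in zip(volume.shape, target_shape)]
    crop = spectrum[
        starts[0] : starts[0] + target_shape[0],
        starts[1] : starts[1] + target_shape[1],
        starts[2] : starts[2] + target_shape[2],
    ]
    out = torch.fft.ifftn(torch.fft.ifftshift(crop)).real
    # Compensate for the change in voxel count so mean intensity is preserved
    return out * (out.numel() / volume.numel())

vol = torch.rand(32, 32, 32)
small = fourier_crop(vol, (16, 16, 16))
print(tuple(small.shape))  # (16, 16, 16)
```

Fourier extension is the mirror operation: zero-pad the centered spectrum out to the target shape before inverting. Doing this with `torch.fft` keeps the resampling on the GPU, matching the "Move rescaling to torch GPU" commit.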
Review comment: I don't know how you feel, but most of the time I don't like the default printing of local variables for debugging. It's a bit annoying because it prints the entire model weights. Of course, it also provides more detailed information, but maybe this could be an advanced option?
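This is what commit f294ab3 ("Remove local variables display from pretty print") addresses. In Typer-based CLIs, the local-variable dump comes from Typer's rich-formatted exception tracebacks; a common way to turn it off (an assumption about how such a fix typically looks, not necessarily this PR's exact change) is:

```python
import typer

# Hypothetical CLI app: disabling the locals display keeps tracebacks
# readable when locals hold large objects such as model weights.
app = typer.Typer(pretty_exceptions_show_locals=False)

@app.command()
def segment(tomogram: str) -> None:
    """Dummy command for illustration only."""
    print(f"segmenting {tomogram}")
```

Making this a user-facing flag (an "advanced option", as suggested) would let users re-enable the verbose tracebacks when they actually want the extra detail.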