Skip to content

Commit

Permalink
Logging statements (#87)
Browse files Browse the repository at this point in the history
* switch from print to logging statements

* concatenate strings
  • Loading branch information
LorenzLamm authored Jan 11, 2025
1 parent 7c37049 commit 079201e
Show file tree
Hide file tree
Showing 16 changed files with 178 additions and 74 deletions.
12 changes: 11 additions & 1 deletion src/membrain_seg/segmentation/cli/fine_tune_cli.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import logging
from typing import List, Optional

from typer import Option
from typing_extensions import Annotated

from ..finetune import fine_tune as _fine_tune
from .cli import OPTION_PROMPT_KWARGS as PKWARGS
from .cli import cli

logging.basicConfig(level=logging.INFO)


@cli.command(name="finetune", no_args_is_help=True)
def finetune(
Expand All @@ -25,6 +27,8 @@ def finetune(
),
):
"""
CLI for fine-tuning a pre-trained model.
Initiates fine-tuning of a pre-trained model on new datasets
and validation on original datasets.
Expand All @@ -51,6 +55,8 @@ def finetune(
using the provided model checkpoint.
The actual fine-tuning logic resides in the function '_fine_tune'.
"""
from ..finetune import fine_tune as _fine_tune

finetune_learning_rate = 1e-5
log_dir = "logs_finetune/"
batch_size = 2
Expand Down Expand Up @@ -157,6 +163,8 @@ def finetune_advanced(
),
):
"""
CLI for fine-tuning a pre-trained model with advanced options.
Initiates fine-tuning of a pre-trained model on new datasets
and validation on original datasets with more advanced options.
Expand Down Expand Up @@ -217,6 +225,8 @@ def finetune_advanced(
using the provided model checkpoint.
The actual fine-tuning logic resides in the function '_fine_tune'.
"""
from ..finetune import fine_tune as _fine_tune

_fine_tune(
pretrained_checkpoint_path=pretrained_checkpoint_path,
finetune_data_dir=finetune_data_dir,
Expand Down
45 changes: 34 additions & 11 deletions src/membrain_seg/segmentation/cli/segment_cli.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import logging
import os
from typing import List

from typer import Option

from membrain_seg.segmentation.dataloading.data_utils import (
load_tomogram,
store_tomogram,
)

from ..connected_components import connected_components as _connected_components
from ..segment import segment as _segment
from .cli import OPTION_PROMPT_KWARGS as PKWARGS
from .cli import cli

logging.basicConfig(level=logging.INFO)


@cli.command(name="segment", no_args_is_help=True)
def segment(
Expand All @@ -27,7 +23,7 @@ def segment(
out_folder: str = Option( # noqa: B008
"./predictions", help="Path to the folder where segmentations should be stored."
),
rescale_patches: bool = Option( # noqa: B008
rescale_patches: bool = Option( # noqa: B008
False, help="Should patches be rescaled on-the-fly during inference?"
),
in_pixel_size: float = Option( # noqa: B008
Expand All @@ -36,7 +32,7 @@ def segment(
(default: 10 Angstrom)",
),
out_pixel_size: float = Option( # noqa: B008
10.,
10.0,
help="Pixel size of the output segmentation in Angstrom. \
(default: 10 Angstrom; should normally stay at 10 Angstrom)",
),
Expand Down Expand Up @@ -75,6 +71,20 @@ def segment(
membrain segment --tomogram-path <path-to-your-tomo>
--ckpt-path <path-to-your-model>
"""
from membrain_seg.segmentation.segment import segment as _segment

print("Segmenting tomogram", tomogram_path)
print("")
print(
"This can take several minutes. If you are bored, why not learn about \
what's happening under the hood by reading the MemBrain v2 preprint?"
)
print(
"MemBrain v2: an end-to-end tool for the analysis of membranes in \
cryo-electron tomography"
)
print("https://www.biorxiv.org/content/10.1101/2024.01.05.574336v1")
print("")
_segment(
tomogram_path=tomogram_path,
ckpt_path=ckpt_path,
Expand Down Expand Up @@ -115,6 +125,14 @@ def components(
membrain components --tomogram-path <path-to-your-tomo>
--connected-component-thres 5
"""
from membrain_seg.segmentation.connected_components import (
connected_components as _connected_components,
)
from membrain_seg.segmentation.dataloading.data_utils import (
load_tomogram,
store_tomogram,
)

segmentation = load_tomogram(segmentation_path)
conn_comps = _connected_components(
binary_seg=segmentation.data, size_thres=connected_component_thres
Expand Down Expand Up @@ -159,12 +177,17 @@ def thresholds(
indicating the threshold values in the default 'predictions' folder or
in the folder specified by the user.
"""
from membrain_seg.segmentation.dataloading.data_utils import (
load_tomogram,
store_tomogram,
)

scoremap = load_tomogram(scoremap_path)
score_data = scoremap.data
if not isinstance(thresholds, list):
thresholds = [thresholds]
for threshold in thresholds:
print("Thresholding at", threshold)
logging.info("Thresholding at" + str(threshold))
thresholded_data = score_data > threshold
segmentation = scoremap
segmentation.data = thresholded_data
Expand All @@ -174,4 +197,4 @@ def thresholds(
+ f"_threshold_{threshold}.mrc",
)
store_tomogram(filename=out_file, tomogram=segmentation)
print("Saved thresholded scoremap to", out_file)
logging.info("Saved thresholded scoremap to " + out_file)
28 changes: 21 additions & 7 deletions src/membrain_seg/segmentation/cli/ske_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,10 @@

from typer import Option

from membrain_seg.segmentation.dataloading.data_utils import (
load_tomogram,
store_tomogram,
)

from ..skeletonize import skeletonization as _skeletonization
from .cli import cli

logging.basicConfig(level=logging.INFO)


@cli.command(name="skeletonize", no_args_is_help=True)
def skeletonize(
Expand Down Expand Up @@ -55,7 +51,25 @@ def skeletonize(
membrain skeletonize --label-path <path> --out-folder <output-directory>
--batch-size <batch-size>
"""
# Assuming _skeletonization function is already defined and can handle batch_size
from membrain_seg.segmentation.dataloading.data_utils import (
load_tomogram,
store_tomogram,
)

from ..skeletonize import skeletonization as _skeletonization

print("Skeletonizing the segmentation")
print("")
print(
"This can take several minutes. If you are bored, why not learn about what's \
happening under the hood by reading the MemBrain v2 preprint?"
)
print(
"MemBrain v2: an end-to-end tool for the analysis of membranes in \
cryo-electron tomography"
)
print("https://www.biorxiv.org/content/10.1101/2024.01.05.574336v1")
print("")

segmentation = load_tomogram(label_path)
ske = _skeletonization(segmentation=segmentation.data, batch_size=batch_size)
Expand Down
36 changes: 35 additions & 1 deletion src/membrain_seg/segmentation/cli/train_cli.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import logging
from typing import List, Optional

from typer import Option
from typing_extensions import Annotated

from ..train import train as _train
from .cli import OPTION_PROMPT_KWARGS as PKWARGS
from .cli import cli

logging.basicConfig(level=logging.INFO)


@cli.command(name="train", no_args_is_help=True)
def train(
Expand All @@ -33,6 +35,22 @@ def train(
The actual training logic resides in the function '_train'.
"""
from ..train import train as _train

print("Training the model")
print("")
print("")
print(
"This will take forever. If you are bored, why not learn about what's \
happening under the hood by reading the MemBrain v2 preprint?"
)
print(
"MemBrain v2: an end-to-end tool for the analysis of membranes in \
cryo-electron tomography"
)
print("https://www.biorxiv.org/content/10.1101/2024.01.05.574336v1")
print("")

log_dir = "./logs"
batch_size = 2
num_workers = 1
Expand Down Expand Up @@ -165,6 +183,22 @@ def train_advanced(
The actual training logic resides in the function '_train'.
"""
from ..train import train as _train

print("Training the model")
print("")
print("")
print(
"This will take forever. If you are bored, why not learn about what's \
happening under the hood by reading the MemBrain v2 preprint?"
)
print(
"MemBrain v2: an end-to-end tool for the analysis of membranes in \
cryo-electron tomography"
)
print("https://www.biorxiv.org/content/10.1101/2024.01.05.574336v1")
print("")

_train(
data_dir=data_dir,
log_dir=log_dir,
Expand Down
12 changes: 7 additions & 5 deletions src/membrain_seg/segmentation/connected_components.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import numpy as np
from scipy import ndimage

Expand Down Expand Up @@ -31,17 +33,17 @@ def connected_components(binary_seg: np.ndarray, size_thres: int = None):
the number of connected components found. All background voxels
are zero.
"""
print("Computing connected components.")
logging.info("Computing connected components.")
# Get 3D connected components
structure = np.ones((3, 3, 3))
labeled_array, num_features = ndimage.label(binary_seg, structure=structure)

# remove small clusters
if size_thres is not None and size_thres > 1:
print(
"Removing components smaller than",
size_thres,
"voxels. (This can take a while)",
logging.info(
"Removing components smaller than "
+ str(size_thres)
+ " voxels. (This can take a while)",
)
sizes = ndimage.sum(binary_seg, labeled_array, range(1, num_features + 1))
too_small = np.nonzero(sizes < size_thres)[0] + 1 # features labeled from 1
Expand Down
14 changes: 13 additions & 1 deletion src/membrain_seg/segmentation/dataloading/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import csv
import logging
import os
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union

Expand Down Expand Up @@ -200,7 +202,7 @@ def store_segmented_tomograms(
data=predictions_np_thres, header=mrc_header, voxel_size=voxel_size
)
store_tomogram(out_file_thres, out_tomo)
print("MemBrain has finished segmenting your tomogram.")
logging.info("MemBrain has finished segmenting your tomogram.")
return out_file_thres


Expand Down Expand Up @@ -286,6 +288,16 @@ def load_tomogram(
and voxel size.
"""
warnings.filterwarnings(
"ignore",
message="Map ID string not found - \
not an MRC file, or file is corrupt",
)
warnings.filterwarnings(
"ignore",
message="Unrecognised machine stamp: \
0x00 0x00 0x00 0x00",
)
with mrcfile.open(filename, permissive=True) as tomogram:
data = tomogram.data.copy()
data = np.transpose(data, (2, 1, 0))
Expand Down
3 changes: 2 additions & 1 deletion src/membrain_seg/segmentation/dataloading/memseg_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
from typing import Dict

Expand Down Expand Up @@ -269,7 +270,7 @@ def load_data(self) -> None:
-----
This function assumes the image and label files are in NIFTI format.
"""
print("Loading images into dataset.")
logging.info("Loading images into dataset.")
self.imgs = []
self.labels = []
self.dataset_labels = []
Expand Down
8 changes: 5 additions & 3 deletions src/membrain_seg/segmentation/finetune.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
Expand Down Expand Up @@ -105,13 +107,13 @@ def fine_tune(
surf_dice_weight=surf_dice_weight,
surf_dice_tokens=surf_dice_tokens,
)
print("————————————————————————————————————————————————————————")
print(
logging.info("————————————————————————————————————————————————————————")
logging.info(
f"Pretrained Checkpoint:\n"
f" '{pretrained_checkpoint_path}' \n"
f" Path to the pretrained model checkpoint."
)
print("\n")
logging.info("\n")

# Initialize the data module with fine-tuning datasets
# New data for finetuning and old data for validation
Expand Down
15 changes: 8 additions & 7 deletions src/membrain_seg/segmentation/networks/unet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from functools import partial
from typing import Dict, Tuple

Expand Down Expand Up @@ -258,9 +259,9 @@ def on_train_epoch_end(self):
self.log("train_surf_dice", mean_train_surf_dice)

self.training_step_outputs = []
print("EPOCH Training loss", mean_train_loss.item())
print("EPOCH Training acc", mean_train_acc.item())
print("EPOCH Training surface dice", mean_train_surf_dice.item())
logging.info("EPOCH Training loss", mean_train_loss.item())
logging.info("EPOCH Training acc", mean_train_acc.item())
logging.info("EPOCH Training surface dice", mean_train_surf_dice.item())
# Accuracy not the most informative metric, but a good sanity check
return {"train_loss": mean_train_loss}

Expand Down Expand Up @@ -334,8 +335,8 @@ def on_validation_epoch_end(self):
self.log("val_accuracy", mean_val_acc)

self.validation_step_outputs = []
print("EPOCH Validation loss", mean_val_loss.item())
print("EPOCH Validation dice", mean_val_dice)
print("EPOCH Validation surface dice", mean_val_surf_dice.item())
print("EPOCH Validation acc", mean_val_acc.item())
logging.info("EPOCH Validation loss", mean_val_loss.item())
logging.info("EPOCH Validation dice", mean_val_dice)
logging.info("EPOCH Validation surface dice", mean_val_surf_dice.item())
logging.info("EPOCH Validation acc", mean_val_acc.item())
return {"val_loss": mean_val_loss, "val_metric": mean_val_dice}
Loading

0 comments on commit 079201e

Please sign in to comment.