diff --git a/src/membrain_seg/segmentation/cli/train_cli.py b/src/membrain_seg/segmentation/cli/train_cli.py index 8cc6132..488c03d 100644 --- a/src/membrain_seg/segmentation/cli/train_cli.py +++ b/src/membrain_seg/segmentation/cli/train_cli.py @@ -36,6 +36,7 @@ def train( log_dir = "./logs" batch_size = 2 num_workers = 1 + on_the_fly_dataloading = False max_epochs = 1000 aug_prob_to_one = True use_deep_supervision = True @@ -47,6 +48,7 @@ def train( log_dir=log_dir, batch_size=batch_size, num_workers=num_workers, + on_the_fly_dataloading=on_the_fly_dataloading, max_epochs=max_epochs, aug_prob_to_one=aug_prob_to_one, use_deep_supervision=use_deep_supervision, @@ -76,6 +78,10 @@ def train_advanced( 8, help="Number of worker threads for loading data", ), + on_the_fly_dataloading: bool = Option( # noqa: B008 + False, + help="Whether to load data on the fly. This is useful for large datasets.", + ), max_epochs: int = Option( # noqa: B008 1000, help="Maximum number of epochs for training", @@ -131,6 +137,8 @@ def train_advanced( Number of samples per batch, by default 2. num_workers : int Number of worker threads for data loading, by default 1. + on_the_fly_dataloading : bool + Determines whether to load data on the fly, by default False. max_epochs : int Maximum number of training epochs, by default 1000. aug_prob_to_one : bool @@ -162,6 +170,7 @@ def train_advanced( log_dir=log_dir, batch_size=batch_size, num_workers=num_workers, + on_the_fly_dataloading=on_the_fly_dataloading, max_epochs=max_epochs, aug_prob_to_one=aug_prob_to_one, use_deep_supervision=use_deep_supervision, diff --git a/src/membrain_seg/segmentation/dataloading/memseg_dataset.py b/src/membrain_seg/segmentation/dataloading/memseg_dataset.py index a099ac3..7b5aef9 100644 --- a/src/membrain_seg/segmentation/dataloading/memseg_dataset.py +++ b/src/membrain_seg/segmentation/dataloading/memseg_dataset.py @@ -59,6 +59,7 @@ def __init__( train: bool = False, aug_prob_to_one: bool = False, patch_size: int = 160, + on_the_fly_loading: bool = False, ) -> None: """ Constructs all the necessary attributes for the CryoETMemSegDataset object. @@ -76,12 +77,16 @@ def __init__( to one or not. patch_size : int, default 160 The size of the patches to be extracted from the images. + on_the_fly_loading : bool, default False + A flag indicating whether the data should be loaded on the fly or not. """ self.train = train self.img_folder, self.label_folder = img_folder, label_folder self.patch_size = patch_size + self.on_the_fly_loading = on_the_fly_loading self.initialize_imgs_paths() - self.load_data() + if not self.on_the_fly_loading: + self.load_data() self.transforms = ( get_training_transforms(prob_to_one=aug_prob_to_one) if self.train @@ -104,13 +109,20 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: Dict[str, np.ndarray] A dictionary containing an image and its corresponding label. """ - idx_dict = { - "image": np.expand_dims(self.imgs[idx], 0), - "label": np.expand_dims(self.labels[idx], 0), - } + if self.on_the_fly_loading: + idx_dict = self.load_data_sample(idx) + idx_dict["image"] = np.expand_dims(idx_dict["image"], 0) + idx_dict["label"] = np.expand_dims(idx_dict["label"], 0) + ds_label = idx_dict["dataset"] + else: + idx_dict = { + "image": np.expand_dims(self.imgs[idx], 0), + "label": np.expand_dims(self.labels[idx], 0), + } + ds_label = self.dataset_labels[idx] idx_dict = self.get_random_crop(idx_dict) idx_dict = self.transforms(idx_dict) - idx_dict["dataset"] = self.dataset_labels[idx] + idx_dict["dataset"] = ds_label # transforms remove the dataset token return idx_dict def __len__(self) -> int: @@ -228,6 +240,27 @@ def get_random_crop(self, idx_dict: Dict[str, np.ndarray]) -> Dict[str, np.ndarr ), f"Image shape is {img.shape} instead of {self.patch_size}" return {"image": img, "label": label} + def load_data_sample(self, idx: int) -> Dict[str, np.ndarray]: + """ + Loads a single image-label pair from the dataset. + + Parameters + ---------- + idx : int + The index of the sample to be loaded. + + Returns + ------- + Dict[str, np.ndarray] + A dictionary containing an image and its corresponding label. + """ + label = read_nifti(self.data_paths[idx][1]) + img = read_nifti(self.data_paths[idx][0]) + label = np.transpose(label, (1, 2, 0)) + img = np.transpose(img, (1, 2, 0)) + ds_token = get_dataset_token(self.data_paths[idx][0]) + return {"image": img, "label": label, "dataset": ds_token} + def load_data(self) -> None: """ Loads image-label pairs into memory from the specified directories. @@ -240,18 +273,11 @@ def load_data(self) -> None: self.imgs = [] self.labels = [] self.dataset_labels = [] - for entry in tqdm(self.data_paths): - label = read_nifti( - entry[1] - ) # TODO: Change this to be applicable to .mrc images - img = read_nifti(entry[0]) - label = np.transpose( - label, (1, 2, 0) - ) # TODO: Needed? Probably no? z-axis should not matter - img = np.transpose(img, (1, 2, 0)) - self.imgs.append(img) - self.labels.append(label) - self.dataset_labels.append(get_dataset_token(entry[0])) + for entry_num in tqdm(range(len(self.data_paths))): + sample_dict = self.load_data_sample(entry_num) + self.imgs.append(sample_dict["image"]) + self.labels.append(sample_dict["label"]) + self.dataset_labels.append(sample_dict["dataset"]) def initialize_imgs_paths(self) -> None: """ diff --git a/src/membrain_seg/segmentation/dataloading/memseg_pl_datamodule.py b/src/membrain_seg/segmentation/dataloading/memseg_pl_datamodule.py index e73f597..3407d10 100644 --- a/src/membrain_seg/segmentation/dataloading/memseg_pl_datamodule.py +++ b/src/membrain_seg/segmentation/dataloading/memseg_pl_datamodule.py @@ -41,7 +41,14 @@ class MemBrainSegDataModule(pl.LightningDataModule): The test dataset. """ - def __init__(self, data_dir, batch_size, num_workers, aug_prob_to_one=False): + def __init__( + self, + data_dir, + batch_size, + num_workers, + on_the_fly_dataloading=False, + aug_prob_to_one=False, + ): """Initialization of data paths and data loaders. The data_dir should have the following structure: @@ -72,6 +79,7 @@ def __init__(self, data_dir, batch_size, num_workers, aug_prob_to_one=False): self.batch_size = batch_size self.num_workers = num_workers self.aug_prob_to_one = aug_prob_to_one + self.on_the_fly_dataloading = on_the_fly_dataloading def setup(self, stage: Optional[str] = None): """ @@ -91,14 +99,21 @@ def setup(self, stage: Optional[str] = None): label_folder=self.train_lab_dir, train=True, aug_prob_to_one=self.aug_prob_to_one, + on_the_fly_loading=self.on_the_fly_dataloading, ) self.val_dataset = CryoETMemSegDataset( - img_folder=self.val_img_dir, label_folder=self.val_lab_dir, train=False + img_folder=self.val_img_dir, + label_folder=self.val_lab_dir, + train=False, + on_the_fly_loading=self.on_the_fly_dataloading, ) if stage in (None, "test"): self.test_dataset = CryoETMemSegDataset( - self.data_dir, test=True, transform=self.transform + self.data_dir, + test=True, + transform=self.transform, + on_the_fly_loading=self.on_the_fly_dataloading, ) # TODO: How to do prediction? def train_dataloader(self) -> DataLoader: diff --git a/src/membrain_seg/segmentation/train.py b/src/membrain_seg/segmentation/train.py index 8475e37..bce3044 100644 --- a/src/membrain_seg/segmentation/train.py +++ b/src/membrain_seg/segmentation/train.py @@ -22,6 +22,7 @@ def train( log_dir: str = "logs/", batch_size: int = 2, num_workers: int = 8, + on_the_fly_dataloading: bool = False, max_epochs: int = 1000, aug_prob_to_one: bool = False, use_deep_supervision: bool = False, @@ -48,6 +49,8 @@ def train( Number of samples per batch of input data. num_workers : int, optional Number of subprocesses to use for data loading. + on_the_fly_dataloading : bool, optional + If True, data is loaded on the fly. max_epochs : int, optional Maximum number of epochs to train for. aug_prob_to_one : bool, optional @@ -74,6 +77,7 @@ def train( log_dir=log_dir, batch_size=batch_size, num_workers=num_workers, + on_the_fly_dataloading=on_the_fly_dataloading, max_epochs=max_epochs, aug_prob_to_one=aug_prob_to_one, use_deep_supervision=use_deep_supervision, @@ -88,6 +92,7 @@ def train( data_dir=data_dir, batch_size=batch_size, num_workers=num_workers, + on_the_fly_dataloading=on_the_fly_dataloading, aug_prob_to_one=aug_prob_to_one, ) diff --git a/src/membrain_seg/segmentation/training/training_param_summary.py b/src/membrain_seg/segmentation/training/training_param_summary.py index 67277c8..1837811 100644 --- a/src/membrain_seg/segmentation/training/training_param_summary.py +++ b/src/membrain_seg/segmentation/training/training_param_summary.py @@ -3,6 +3,7 @@ def print_training_parameters( log_dir: str = "logs/", batch_size: int = 2, num_workers: int = 8, + on_the_fly_dataloading: bool = False, max_epochs: int = 1000, aug_prob_to_one: bool = False, use_deep_supervision: bool = False, @@ -25,6 +26,8 @@ def print_training_parameters( Number of samples per batch of input data. num_workers : int, optional Number of subprocesses to use for data loading. + on_the_fly_dataloading : bool, optional + If True, data is loaded on the fly. max_epochs : int, optional Maximum number of epochs to train for. aug_prob_to_one : bool, optional @@ -68,6 +71,12 @@ def print_training_parameters( "loading.".format(num_workers) ) print("————————————————————————————————————————————————————————") + on_the_fly_status = "Enabled" if on_the_fly_dataloading else "Disabled" + print( + "On-the-Fly Data Loading:\n {} \n If enabled, data is loaded on " + "the fly.".format(on_the_fly_status) + ) + print("————————————————————————————————————————————————————————") print(f"Max Epochs:\n {max_epochs} \n Maximum number of training epochs.") print("————————————————————————————————————————————————————————") aug_status = "Enabled" if aug_prob_to_one else "Disabled"