diff --git a/src/murfey/util/lif.py b/src/murfey/util/lif.py index 2be7a089..358200f2 100644 --- a/src/murfey/util/lif.py +++ b/src/murfey/util/lif.py @@ -6,11 +6,11 @@ from __future__ import annotations import logging +import multiprocessing as mp from pathlib import Path from typing import Generator, List, Optional, Tuple from xml.etree import ElementTree as ET -# import matplotlib.pyplot as plt import numpy as np from readlif.reader import LifFile from tifffile import imwrite @@ -28,8 +28,8 @@ def get_xml_metadata( save_xml: Optional[Path] = None, ) -> ET.Element: """ - Extracts and returns the file metadata as a formatted XML tree, and optionally - saves it as an XML file to the specified file path. + Extracts and returns the file metadata as a formatted XML tree. Provides option + to save it as an XML file to the specified file path """ # Use readlif function to get XML metadata @@ -80,7 +80,7 @@ def _find_elements_recursively( return elem_list -def raise_bit_depth_error(bit_depth: int): +def raise_BitDepthError(bit_depth: int): """ Raises an exception if the bit depth value provided is not one that NumPy can handle. @@ -95,14 +95,15 @@ def raise_bit_depth_error(bit_depth: int): def change_bit_depth( array: np.ndarray, - bit_depth: int, + target_bit_depth: int, ) -> np.ndarray: """ Change the bit depth of the array without changing the values (barring rounding). 
""" - # Use shorter terms + # Use shorter terms in function arr = array + bit_depth = target_bit_depth # NumPy defaults to float64; revert back to unsigned int if bit_depth == 8: @@ -114,7 +115,7 @@ def change_bit_depth( elif bit_depth == 64: arr = arr.astype(np.uint64) else: - raise_bit_depth_error(bit_depth) + raise_BitDepthError(bit_depth) return arr @@ -131,9 +132,7 @@ def rescale_across_channel( # Check that bit depth is valid before processing even begins if not any(bit_depth == b for b in [8, 16, 32, 64]): - raise_bit_depth_error(bit_depth) - else: - pass # Proceed to rest of function + raise_BitDepthError(bit_depth) # Use shorter variable names arr = array @@ -158,14 +157,8 @@ def rescale_across_channel( 2**bit_depth - 1 ) # Ensure data points don't exceed bit depth (max bit is 2**n - 1) - # This step probably not needed - # Overwrite values that exceed current channel bit depth - # arr[arr >= (2**bit_depth - 1)] = ( - # 2**bit_depth - 1 - # ) - # Change bit depth back to initial one - arr = change_bit_depth(arr, bit_depth) + arr = change_bit_depth(array=arr, target_bit_depth=bit_depth) return arr @@ -174,7 +167,7 @@ def rescale_to_bit_depth( array: np.ndarray, initial_bit_depth: int, target_bit_depth: int, -) -> Tuple[np.ndarray, int]: +) -> np.ndarray: """ Rescales the pixel values of the array to fit within the desired channel bit depth. Returns the array and the target bit depth as a tuple. 
@@ -187,24 +180,192 @@ def rescale_to_bit_depth(   # Check that target bit depth is allowed if not any(bit_final == b for b in [8, 16, 32, 64]): - raise_bit_depth_error(bit_final) + raise_BitDepthError(bit_final) # Rescale (DIVIDE BEFORE MULTIPLY) arr = (arr / (2**bit_init - 1)) * (2**bit_final - 1) - # This step probably not needed anymore - # Overwrite pixels that exceed channel bit depth - # arr[arr >= (2**bit_final - 1)] = 2**bit_final - 1 - # Change to correct unsigned integer type - arr = change_bit_depth(arr, bit_final) + arr = change_bit_depth(array=arr, target_bit_depth=bit_final) + + return arr + - return arr, bit_final +def process_image_stack( + file: Path, + scene_num: int, + metadata: ET.Element, + save_dir: Path, +): + """ + Takes the LIF file and its corresponding metadata and loads the relevant sub-stack, + with each channel as its own array. Rescales their intensity values to utilise the + whole channel, scales them down to 8-bit, then saves each array as a separate + TIFF image stack. 
+ """ + + # Load LIF file + file_name = file.stem.replace(" ", "_") + image = LifFile(str(file)).get_image(scene_num) + + # Get name of sub-image + img_name = metadata.attrib["Name"].replace(" ", "_") # Remove spaces + logger.info(f"Processing {file_name}-{img_name}") + + # Create save dirs for TIFF files and their metadata + img_dir = save_dir / img_name + img_xml_dir = img_dir / "metadata" + for folder in [img_dir, img_xml_dir]: + if not folder.exists(): + folder.mkdir(parents=True) + logger.info(f"Created {folder}") + else: + logger.info(f"{folder} already exists") + + # Save image stack XML metadata (all channels together) + img_xml_file = img_xml_dir / (img_name + ".xml") + metadata_tree = ET.ElementTree(metadata) + ET.indent(metadata_tree, " ") + metadata_tree.write(img_xml_file, encoding="utf-8") + logger.info(f"Image stack metadata saved to {img_xml_file}") + + # Load channels + channel_elem = metadata.findall( + "Data/Image/ImageDescription/Channels/ChannelDescription" + ) + channels: list = [ + channel_elem[c].attrib["LUTName"].lower() for c in range(len(channel_elem)) + ] + + # Load timestamps and dimensions + # Might be useful in the future + # timestamps = elem.find("Data/Image/TimeStampList") + # dimensions = elem.findall( + # "Data/Image/ImageDescription/Dimensions/DimensionDescription" + # ) + + # Generate slice labels for later + num_frames = image.dims.z + image_labels = [f"{f}" for f in range(num_frames)] + + # Get x, y, and z scales + # Get resolution (pixels per um) + x_res = image.scale[0] + y_res = image.scale[1] + + # Get pixel size (um per pixel) + # Might be useful in the future + # x_scale = 1 / x_res + # y_scale = 1 / y_res + + # Check that depth axis exists + z_res: float = image.scale[2] if num_frames > 1 else float(0) # Pixels per um + z_scale: float = 1 / z_res if num_frames > 1 else float(0) # um per pixel + + # Process channels as individual TIFFs + for c in range(len(channels)): + + # Get color + color = channels[c] + 
logger.info(f"Processing {color} channel") + + # Load image stack to array + logger.info("Loading image stack") + for z in range(num_frames): + frame = image.get_frame(z=z, t=0, c=c) # PIL object; array-like + if z == 0: + arr = np.array([frame]) + else: + arr = np.append(arr, [frame], axis=0) + logger.info( + f"{file_name}-{img_name}-{color} has the dimensions {np.shape(arr)} \n" + f"Min value: {np.min(arr)} \n" + f"Max value: {np.max(arr)} \n" + ) + + # Initial rescaling if bit depth not 8, 16, 32, or 64-bit + bit_depth = image.bit_depth[c] # Initial bit depth + if not any(bit_depth == b for b in [8, 16, 32, 64]): + logger.info(f"{bit_depth}-bit is non-standard; converting to 16-bit") + arr = ( + rescale_to_bit_depth( + array=arr, initial_bit_depth=bit_depth, target_bit_depth=16 + ) + if np.max(arr) > 0 + else change_bit_depth( + array=arr, + target_bit_depth=16, + ) + ) + bit_depth = 16 # Overwrite + + # Rescale intensity values for fluorescent channels + # Currently pre-emptively converting for all coloured ones + if any( + color in key + for key in [ + "blue", # Not tested + "cyan", # Not tested + "green", + "magenta", # Not tested + "red", + "yellow", # Not tested + ] + ): + logger.info(f"Rescaling {color} channel across channel depth") + arr = ( + rescale_across_channel( + array=arr, + bit_depth=bit_depth, + percentile_range=(0.5, 99.5), + round_to=16, + ) + if np.max(arr) > 0 + else arr + ) + + # Convert to 8-bit + logger.info("Converting to 8-bit image") + bit_depth_new = 8 + arr = ( + rescale_to_bit_depth( + array=arr, + initial_bit_depth=bit_depth, + target_bit_depth=bit_depth_new, + ) + if np.max(arr) > 0 + else change_bit_depth( + array=arr, + target_bit_depth=bit_depth_new, + ) + ) + + # Save as a greyscale TIFF + save_name = img_dir.joinpath(color + ".tiff") + logger.info(f"Saving {color} image as {save_name}") + imwrite( + save_name, + arr, + imagej=True, # ImageJ compatible + photometric="minisblack", # Grayscale image + shape=np.shape(arr), + 
dtype=arr.dtype, + resolution=(x_res * 10**6 / 10**6, y_res * 10**6 / 10**6), + metadata={ + "spacing": z_scale, + "unit": "micron", + "axes": "ZYX", + "Labels": image_labels, + }, + ) + + return True def convert_lif_to_tiff( file: Path, - root_folder: str, # The name of the folder to be treated as the root + root_folder: str, # Name of the folder under which all raw LIF files are stored + number_of_processes: int = 1, # For parallel processing ): """ Takes a LIF file, extracts its metadata as an XML tree, then parses through the @@ -212,8 +373,7 @@ def convert_lif_to_tiff( image stack. It uses information stored in the XML metadata to name the individual image stacks. - FOLDER STRUCTURE - ================ + FOLDER STRUCTURE: Here is the folder structure of a typical DLS eBIC experiment session, with the folders created as part of the workflow shown as well. @@ -234,45 +394,40 @@ def convert_lif_to_tiff( | |_ metadata <- Individual XML files saved here (not yet implemented) """ - # Set up new directories - # Identify the root directory - root_parts = [] - for p in file.parts: # Iterate through parts until hitting root folder - if p.lower() == root_folder.lower(): # Eliminate case-sensitivity - break - root_parts.append(p) - else: + # Validate processor count input + num_procs = number_of_processes # Use shorter phrase in script + if num_procs < 1: + logger.warning("Processor count set to zero or less; resetting to 1") + num_procs = 1 + + # Folder for processed files with same structure as old one + file_name = file.stem.replace(" ", "_") # Replace spaces + path_parts = list(file.parts) + new_root_folder = "processed" + # Rewrite string in-place + for p in range(len(path_parts)): + part = path_parts[p] + # Omit initial "/" in Linux file systems for subsequent rejoining + if part == "/": + path_parts[p] = "" + # Rename designated raw folder to "processed" + if part.lower() == root_folder.lower(): # Remove case-sensitivity + path_parts[p] = new_root_folder + break # 
Do for first instance only + # If specified folder not found by end of string, log as error + if new_root_folder not in path_parts: logger.error( - f"Subpath {sanitise(root_folder)} was not found in image path {sanitise(str(file))}" + f"Subpath {sanitise(root_folder)} was not found in image path " + f"{sanitise(str(file))}" ) return None - root_dir = Path("/".join(root_parts)) # Session ID folder - - # Get remaining path to file from root folder - child_parts = [] # Path from the root directory to the file - for p in reversed(file.parts): - # Append everything up until the root directory - if p == root_folder: - break - child_parts.append(p) - else: - logger.error( - f"Subpath {sanitise(root_folder)} was not found in image path {sanitise(str(file))}" - ) - child_path = Path( - "/".join(reversed(child_parts)) - ) # Reverse it to get the right order - - # Create directory to store processed files in - process_dir = ( - root_dir / "processed" / child_path.stem - ) # Replace root folder with "processed" + processed_dir = Path("/".join(path_parts)).parent / file_name - # Save raw XML metadata here + # Folder for raw XML metadata raw_xml_dir = file.parent / "metadata" - # Create new folders if not already present - for folder in [process_dir, raw_xml_dir]: + # Create folders if not already present + for folder in [processed_dir, raw_xml_dir]: if not folder.exists(): folder.mkdir(parents=True) logger.info(f"Created {folder}") @@ -288,7 +443,7 @@ def convert_lif_to_tiff( logger.info("Extracting image metadata") xml_root = get_xml_metadata( file=lif_file, - save_xml=raw_xml_dir.joinpath(file.stem + ".xml"), + save_xml=raw_xml_dir.joinpath(file_name + ".xml"), ) # Recursively generate list of metadata-containing elements @@ -296,128 +451,31 @@ def convert_lif_to_tiff( # Check that elements match number of images if not len(elem_list) == len(scene_list): - raise Exception( + raise IndexError( "Error matching metadata list to list of sub-images. 
\n" # Show what went wrong f"Metadata entries: {len(elem_list)} \n" f"Sub-images: {len(scene_list)}" ) - else: - pass # Carry on # Iterate through scenes logger.info("Examining sub-images") - for i in range(len(scene_list)): - # Load image - img = lif_file.get_image(i) # Set sub-image - - # Get name of sub-image - elem = elem_list[i] # Select corresponding element - img_name = elem.attrib["Name"] # Get sub-image name - logger.info(f"Examining {img_name}") - - # Load relevant metadata (channels, dimensions, timestamps etc.) - channels = elem.findall( - "Data/Image/ImageDescription/Channels/ChannelDescription" + # Set up multiprocessing arguments + pool_args = [] + for i in range(len(scene_list)): + pool_args.append( + # Arguments need to be pickle-able; no complex objects allowed + [ # Follow order of args in the function + file, # Reload as LifFile object in the process + i, + elem_list[i], # Corresponding metadata + processed_dir, + ] ) - # Might be useful in the future - # timestamps = elem.find("Data/Image/TimeStampList") - # dimensions = elem.findall( - # "Data/Image/ImageDescription/Dimensions/DimensionDescription" - # ) - - # Create save dirs for TIFF files and their metadata - img_dir = process_dir / img_name - img_xml_dir = img_dir / "metadata" - for folder in [img_dir, img_xml_dir]: - if not folder.exists(): - folder.mkdir(parents=True) - logger.info(f"Created {folder}") - else: - logger.info(f"{folder} already exists") - - # Parijat wants the images in 8-bit; scale down from 16-bit - # Save channels as individual TIFFs - for c in range(len(list(img.get_iter_c()))): - # Get color - color = channels[c].attrib["LUTName"] - logger.info(f"Examining the {color.lower()} channel") - - # Extract image data to array - logger.info("Loading image stack") - arr: np.ndarray = [] # Array to store frames in - # Iterate over slices - for z in range(len(list(img.get_iter_z()))): - frame = img.get_frame(z=z, t=0, c=c) # PIL object; array-like - arr.append(frame) - arr = 
np.array(arr) # Make independent copy of this array - - # Initial rescaling if bit depth not 8, 16, 32, or 64-bit - bit_depth = img.bit_depth[c] # Initial bit depth - if not any(bit_depth == b for b in [8, 16, 32, 64]): - logger.info("Bit depth non-standard, converting to 16-bit") - arr, bit_depth = rescale_to_bit_depth( - array=arr, initial_bit_depth=bit_depth, target_bit_depth=16 - ) - else: - pass - - # Rescale intensity values for fluorescent channels - if any( - color.lower() in key for key in ["red", "green"] - ): # Eliminate case-sensitivity - logger.info(f"Rescaling {color.lower()} channel across channel depth") - arr = rescale_across_channel( - array=arr, - bit_depth=bit_depth, - percentile_range=(0.5, 99.5), - round_to=16, - ) - - # Convert to 8-bit - logger.info("Converting to 8-bit image") - arr, bit_depth = rescale_to_bit_depth( - arr, initial_bit_depth=bit_depth, target_bit_depth=8 - ) - - # Get x, y, and z scales - # Get resolution (pixels per um) - x_res = img.scale[0] - y_res = img.scale[1] - # Might be used in future versions - # Get pixel size (um per pixel) - # x_scale = 1 / x_res - # y_scale = 1 / y_res + # Parallel process image stacks + with mp.Pool(processes=num_procs) as pool: + result = pool.starmap(process_image_stack, pool_args) - # Check that depth axis exists - if not img.scale[2]: - z_res: float = 0 - z_scale: float = 0 # Avoid divide by zero errors - else: - z_res = img.scale[2] # Pixels per um - z_scale = 1 / z_res # um per pixel - - # Generate slice labels - image_labels = [f"{f}" for f in range(len(list(img.get_iter_z())))] - - # Save as a greyscale TIFF - save_name = img_dir.joinpath(color + ".tiff") - logger.info(f"Saving {color.lower()} image as {save_name}") - imwrite( - save_name, - arr, - imagej=True, # ImageJ comppatible - photometric="minisblack", # Grayscale image - shape=np.shape(arr), - dtype=arr.dtype, - resolution=(x_res * 10**6 / 10**6, y_res * 10**6 / 10**6), - metadata={ - "spacing": z_scale, - "unit": "micron", - 
"axes": "ZYX", - "Labels": image_labels, - }, - ) - return None + return result