diff --git a/.github/workflows/build_without_artifacts.yml b/.github/workflows/build_without_artifacts.yml
index 7e69159..9aaec13 100644
--- a/.github/workflows/build_without_artifacts.yml
+++ b/.github/workflows/build_without_artifacts.yml
@@ -23,7 +23,9 @@ jobs:
       - uses: actions/checkout@v2
       - name: Install Nix
-        uses: cachix/install-nix-action@v8
+        uses: cachix/install-nix-action@v12
+        with:
+          nix_path: nixpkgs=channel:nixos-20.09
       # Runs a set of commands using the runners shell
       - name: Build application
         shell: bash
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cd00d7c..59b1fd8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -18,7 +18,7 @@ repos:
     hooks:
       - id: isort
         name: Sort imports
-        always_run: true
+        always_run: false
        args: [--multi-line=3, --trailing-comma, --force-grid-wrap=0, --use-parentheses, --line-width=99]
 #####################################
diff --git a/poretitioner/utils/NTERs_trained_cnn_05152019.py b/poretitioner/utils/NTERs_trained_cnn_05152019.py
index d6bef89..9b20e71 100644
--- a/poretitioner/utils/NTERs_trained_cnn_05152019.py
+++ b/poretitioner/utils/NTERs_trained_cnn_05152019.py
@@ -6,50 +6,57 @@


 class CNN(nn.Module):
+    def __init__(self):
+        self.O_1 = 17
+        self.O_2 = 18
+        self.O_3 = 32
+        self.O_4 = 37
-    O_1 = 17
-    O_2 = 18
-    O_3 = 32
-    O_4 = 37

+        self.K_1 = 3
+        self.K_2 = 1
+        self.K_3 = 4
+        self.K_4 = 2
-    K_1 = 3
-    K_2 = 1
-    K_3 = 4
-    K_4 = 2

+        self.KP_1 = 4
+        self.KP_2 = 4
+        self.KP_3 = 1
+        self.KP_4 = 1
-    KP_1 = 4
-    KP_2 = 4
-    KP_3 = 1
-    KP_4 = 1

+        reshape = 141
-    reshape = 141
-    conv_linear_out = int(
-        m.floor(
-            (
-                m.floor(
-                    (
-                        m.floor(
-                            (
-                                m.floor((m.floor((reshape - K_1 + 1) / KP_1) - K_2 + 1) / KP_2)
-                                - K_3
-                                + 1
-                            )
-                            / KP_3
-                        )
-                        - K_4
-                        + 1
-                    )
-                    / KP_4
-                )
-                ** 2
-            )
-            * O_4
-        )
-    )
-    FN_1 = 148
+        self.conv_linear_out = int(
+            m.floor(
+                (
+                    m.floor(
+                        (
+                            m.floor(
+                                (
+                                    m.floor(
+                                        (
+                                            m.floor((reshape - self.K_1 + 1) / self.KP_1)
+                                            - self.K_2
+                                            + 1
+                                        )
+                                        / self.KP_2
+                                    )
+                                    - self.K_3
+                                    + 1
+                                )
+                                / self.KP_3
+                            )
+                            - self.K_4
+                            + 1
+                        )
+                        / self.KP_4
+                    )
+                    ** 2
+                )
+                * self.O_4
+            )
+        )
+        self.FN_1 = 148

-    def __init__(self):
         super(CNN, self).__init__()
         self.conv1 = nn.Sequential(
@@ -79,7 +86,9 @@ def forward(self, x):
         return x


-def load_cnn(path):
+def load_cnn(state_dict_path, device="cpu"):
     cnn = CNN()
-    cnn = torch.load(path)
+    state_dict = torch.load(state_dict_path, map_location=torch.device(device))
+    cnn.load_state_dict(state_dict, strict=True)
+    cnn.eval()
     return cnn
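Aside on the load_cnn change above: loading a state dict into a freshly constructed module is more robust than unpickling a whole model object, because the checkpoint no longer depends on the import path of the training code. A minimal sketch of the pattern (the MyModel class and checkpoint filename are illustrative only, not part of this diff):

    import torch
    import torch.nn as nn

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)  # toy layer for illustration

        def forward(self, x):
            return self.fc(x)

    model = MyModel()
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine
    state_dict = torch.load("checkpoint.pt", map_location=torch.device("cpu"))
    model.load_state_dict(state_dict, strict=True)  # strict: all keys must match
    model.eval()  # disable dropout/batch-norm training behavior before inference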
diff --git a/poretitioner/utils/classify.py b/poretitioner/utils/classify.py
index 1b55c20..180ac70 100644
--- a/poretitioner/utils/classify.py
+++ b/poretitioner/utils/classify.py
@@ -6,301 +6,230 @@
 This module contains functionality for classifying nanopore captures.
 """
-import logging
 import os
-import warnings
+import re

 import h5py
-import joblib
 import numpy as np
-import pandas as pd
 import torch
 import torch.nn as nn

-from . import raw_signal_utils
-from .NTERs_trained_cnn_05152019 import load_cnn
+import poretitioner.utils.NTERs_trained_cnn_05152019 as pretrained_model
+from poretitioner import logger
+from poretitioner.utils import filter, raw_signal_utils

-warnings.filterwarnings(
-    "ignore"
-)  # TODO : Why is this here? : https://github.com/uwmisl/poretitioner/issues/48
-use_cuda = True
+use_cuda = False  # True
+# TODO : Don't hardcode use of CUDA : https://github.com/uwmisl/poretitioner/issues/41


-def check_capture_rejection(end_capture, voltage_ends, tol_obs=20):
-    # TODO : Do we need this? Should be handled by specifying the filter or
-    # segmentation result to be used.
-    for voltage_end in voltage_ends:
-        if np.abs(end_capture - voltage_end) < tol_obs:
-            return True
-    return False
+def filter_and_classify(config, fast5_fnames, overwrite=False, filter_name=None):
+    local_logger = logger.getLogger()
+    clf_config = config["classify"]
+    classifier_name = clf_config["classifier"]
+    classifier_path = clf_config["classifier_path"]

+    # Load classifier
+    local_logger.info(f"Loading classifier {classifier_name}.")
+    assert classifier_name in ["NTER_cnn", "NTER_rf"]
+    assert classifier_path is not None and len(classifier_path) > 0
+    classifier = init_classifier(classifier_name, classifier_path)

-def get_num_classes(classifier, classifier_name):
-    if classifier_name == "NTER_cnn":
-        return classifier.fc2.out_features
-    elif classifier_name == "NTER_rf":
-        return len(classifier.classes_)
+    # Filter (optional)
+    if filter_name is not None:
+        local_logger.info("Beginning filtering.")
+        filter.filter_and_store_result(config, fast5_fnames, filter_name, overwrite=overwrite)
+        read_path = f"/Filter/{filter_name}/pass"
     else:
-        return
+        read_path = "/"
+
+    # Classify
+    for fast5_fname in fast5_fnames:
+        with h5py.File(fast5_fname, "r+") as f5:
+            classify_fast5_file(f5, clf_config, classifier, classifier_name, read_path)
+
+
+def classify_fast5_file(
+    f5, clf_config, classifier, classifier_run_name, read_path, class_labels=None
+):
+    local_logger = logger.getLogger()
+    local_logger.debug(f"Beginning classification for file {f5.filename}.")
+    classifier_name = clf_config["classifier"]
+    classify_start = clf_config["start_obs"]  # 100 in NTER paper
+    classify_end = clf_config["end_obs"]  # 21000 in NTER paper
+    classifier_conf = clf_config["min_confidence"]
+
+    assert classify_start >= 0 and classify_end >= 0
+    assert classifier_conf is None or (0 <= classifier_conf <= 1)
+
+    local_logger.debug(
+        f"Classification parameters: name: {classifier_name}, "
+        f"range of data points: ({classify_start}, {classify_end}), "
+        f"confidence required to pass: {classifier_conf}"
+    )
+
+    results_path = f"/Classification/{classifier_run_name}"
+    write_classifier_details(f5, clf_config, results_path)
+
+    read_h5group_names = f5.get(read_path)
+    for grp in read_h5group_names:
+        if "read" not in grp:
+            continue
+        read_id = re.findall(r"read_(.*)", str(grp))[0]
+        signal = raw_signal_utils.get_fractional_blockage_for_read(
+            f5, grp, start=classify_start, end=classify_end
+        )
+        y, p = predict_class(classifier_name, classifier, signal, class_labels=class_labels)
+        if classifier_conf is not None:
+            passed_classification = bool(p > classifier_conf)
+        else:
+            passed_classification = None
+        write_classifier_result(f5, results_path, read_id, y, p, passed_classification)
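For reference, classify_fast5_file above reads these clf_config keys; the values shown here are the ones exercised by the tests added later in this diff:

    clf_config = {
        "classifier": "NTER_cnn",  # or "NTER_rf"
        "classifier_path": "poretitioner/utils/model/NTERs_trained_cnn_05152019.statedict.pt",
        "classifier_version": "1.0",
        "start_obs": 100,     # 100 in NTER paper
        "end_obs": 21000,     # 21000 in NTER paper
        "min_confidence": 0.9,  # or None to skip thresholding
    }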
+
+
-# Possible classifier names: NTER_cnn, NTER_rf
-# Prediction classes are 1-9:
-# 0:Y00, 1:Y01, 2:Y02, 3:Y03, 4:Y04, 5:Y05, 6:Y06, 7:Y07, 8:Y08, 9:noise,
-# -1:below conf_thesh
 def init_classifier(classifier_name, classifier_path):
-    if classifier_name == "NTER_cnn":
-        # CNN classifier
-        nanoporeTER_cnn = load_cnn(classifier_path)
-        nanoporeTER_cnn.eval()
+    """Initialize the classification model. Supported classifier names include
+    "NTER_cnn" and "NTER_rf".
+
+    According to documentation for original NTER code:
+    Prediction classes are 1-9:
+    0:Y00, 1:Y01, 2:Y02, 3:Y03, 4:Y04, 5:Y05, 6:Y06, 7:Y07, 8:Y08, 9:noise,
+    -1:below conf_thresh
+
+    Parameters
+    ----------
+    classifier_name : str
+        The name of any supported classifier, currently "NTER_cnn" and "NTER_rf".
+    classifier_path : str
+        Location of the pre-trained model file.
+
+    Returns
+    -------
+    model
+        Classification model (type depends on the specified model).
+
+    Raises
+    ------
+    ValueError
+        Raised if the classifier name is not supported.
+    """
+    if classifier_name == "NTER_cnn":  # CNN classifier
+        if not os.path.exists(classifier_path):
+            raise OSError(f"Classifier path doesn't exist: {classifier_path}")
+        nanoporeTER_cnn = pretrained_model.load_cnn(classifier_path)
         return nanoporeTER_cnn
-    elif classifier_name == "NTER_rf":
-        # Random Forest classifier
+    elif classifier_name == "NTER_rf":  # Random forest classifier
+        if not os.path.exists(classifier_path):
+            raise OSError(f"Classifier path doesn't exist: {classifier_path}")
         # TODO : Improve model maintainability : https://github.com/uwmisl/poretitioner/issues/38
-        return joblib.load(open(classifier_path, "rb"))
+        # return joblib.load(open(classifier_path, "rb"))
+        pass
     else:
-        raise Exception("Invalid classifier name")
-
-
-# Possible filter names: "NTER_general"
-def get_filter_param(filter_name):
-    # TODO deprecate
-    # What filter param each value in the output array represents:
-    # [mean_low, mean_high, stdv_high, med_low, med_high, min_low, min_high,
-    #  max_low, max_high, length, fname_ext]
-    if filter_name == "NTER_general":
-        return [0, 0.45, 1, 0.15, 1, 0.005, 1, 0, 0.65, 20100, ""]
-    else:
-        raise Exception("Invalid filter name")
-
-
-def print_param(filter_param):
-    # TODO deprecate
-    s = ""
-    s += "Mean: " + str((filter_param[0], filter_param[1])) + "\n"
-    s += "Stdv: " + str((0, filter_param[2])) + "\n"
-    s += "Median: " + str((filter_param[3], filter_param[4])) + "\n"
-    s += "Min: " + str((filter_param[5], filter_param[6])) + "\n"
-    s += "Max: " + str((filter_param[7], filter_param[8])) + "\n"
-    s += "Length: " + str(filter_param[9]) + "\n"
-    return s
-
-
-# Returns -1 if classification probability is below confidence threshold
-def classifier_predict(classifier, raw, conf_thresh, classifier_name):
+        raise ValueError(f"Invalid classifier name: {classifier_name}")
+
+
+def predict_class(classifier_name, classifier, raw, class_labels=None):
+    """Runs the classifier using the given raw data as input. Does not apply
+    any kind of confidence threshold.
+
+    Parameters
+    ----------
+    classifier_name : str
+        The name of any supported classifier, currently "NTER_cnn" and "NTER_rf".
+    classifier : model
+        Classification model returned by init_classifier.
+    raw : iterable of floats
+        Time series of nanopore current values (in units of fractionalized current).
+
+    Returns
+    -------
+    int or string
+        Class label
+    float
+        Model score (for NTER_cnn and NTER_rf, it's a probability)
+
+    Raises
+    ------
+    NotImplementedError
+        Raised if the input classifier_name is not supported.
+ """ if classifier_name == "NTER_cnn": X_test = np.array([raw]) - # go from 2D to 3D array (each obs in a capture becomes its own array) + # 2D --> 3D array (each obs in a capture becomes its own array) X_test = X_test.reshape(len(X_test), X_test.shape[1], 1) - X_test = X_test[:, :19881] # take only first 19881 obs of each capture - # break all obs in a captures into 141 groups of 141 (19881 total); each - # capture becomes its own array + if X_test.shape[1] < 19881: + temp = np.zeros((X_test.shape[0], 19881, 1)) + temp[:, : X_test.shape[1], :] = X_test + X_test = temp + X_test = X_test[:, :19881] # First 19881 obs as per NTER paper + # Break capture into 141x141 (19881 total data points) X_test = X_test.reshape(len(X_test), 1, 141, 141) X_test = torch.from_numpy(X_test) - X_test = X_test.cuda() + if use_cuda: + X_test = X_test.cuda() outputs = classifier(X_test) - out = nn.functional.softmax(outputs) + out = nn.functional.softmax(outputs, dim=1) prob, lab = torch.topk(out, 1) - if prob < conf_thresh: - return -1 - lab = lab.cpu().numpy() - return lab[0][0] - else: + if use_cuda: + lab = lab.cpu().numpy()[0][0] + else: + lab = lab.numpy()[0][0] + if class_labels is not None: + lab = class_labels[lab] + prob = prob[0][0].data + return lab, prob + elif classifier_name == "NTER_rf": class_proba = classifier.predict_proba( [[np.mean(raw), np.std(raw), np.min(raw), np.max(raw), np.median(raw)]] )[0] max_proba = np.amax(class_proba) - if max_proba >= conf_thresh: - return np.where(class_proba == max_proba)[0][0] - return -1 - - -# date is a string -# runs is a list of strings, i.e. ["run01_a", "run01_b"] -# filter_name can only be "NTER_general" until more types of filters as added -# classifier_name can be "NTER_cnn" or "NTER_rf" for CNN and Random Forest -# classifiers respectively -# conf_thresh is confidence threshold for classifiers; only classifications >= -# conf_thresh will be written to file -# custom_fname is a custom string to be added to file name -# rej_check ensures that captures which are ejected prematurely are not counted -def filter_and_classify_peptides( - runs, - date, - filter_name, - classifier_name="", - conf_thresh=0.95, - custom_fname="", - rej_check=True, - f5_dir="", - classifier_path="", - capture_fname="", - raw_fname="", - save_dir=".", -): - # TODO : implement fast5 I/O https://github.com/uwmisl/poretitioner/issues/39 - - logger = logging.getLogger("filter_and_classify_peptides") - if logger.handlers: - logger.handlers = [] - logger.setLevel(logging.INFO) - logger.addHandler(logging.StreamHandler()) - - filter_param = get_filter_param(filter_name) - if custom_fname: - filter_param[10] = filter_param[10] + custom_fname + "_" - - logger.info("Params for " + filter_name + " Filter:") - logger.info(print_param(filter_param)) - - if classifier_name: - classifier = init_classifier(classifier_name, classifier_path) - logger.info("Confidence Threshold: " + str(conf_thresh)) - - all_filtered_files = [] - - for run in runs: - logger.info("Starting run chunk " + run) - - # Prep filenames - capture_file = capture_fname % run - raw_file = raw_fname % run - - # TODO parameterize - f5_file = os.path.join(f5_dir, [x for x in os.listdir(f5_dir) if run in x][0]) - logger.debug("f5_file:" + f5_file) - logger.debug("raw_file:" + raw_file) - logger.debug("capture_file:" + capture_file) - - # Read data into variables - capture_meta_df = pd.read_pickle(capture_file) - raw_captures = np.load(raw_file, allow_pickle=True) - f5 = h5py.File(f5_file, "r") - - # Get the voltage & where it 
-        # Get the voltage & where it switches
-        voltage = f5.get("/Device/MetaData").value["bias_voltage"] * 5.0
-        voltage_changes = raw_signal_utils.find_segments_below_threshold(voltage, -180)
-        voltage_ends = [x[1] for x in voltage_changes]
-
-        # Apply length filter
-        capture_meta_df = capture_meta_df[capture_meta_df.duration_obs > filter_param[9]]
-
-        # Apply 5 feature filters and classify
-        if classifier_name:
-            captures = [[] for x in range(0, get_num_classes(classifier, classifier_name))]
-        else:
-            captures = [[]]
-        non_filtered = 0
-        non_classified = 0
-        for i in capture_meta_df.index:
-            # To keep track of filter progress
-            if i % 100 == 0 and i != 0:
-                logger.debug(str(i))
-
-            meta_i = capture_meta_df.loc[i, :]
-
-            # If capture is ejected early, don't count it
-            if rej_check:
-                capture_rejected = check_capture_rejection(meta_i.end_obs, voltage_ends)
-                if not capture_rejected:
-                    continue
-
-            raw_minus_10 = raw_captures[i][10:]  # skip first 10 obs of capture
-
-            new_mean = np.mean(raw_minus_10)
-            new_med = np.median(raw_minus_10)
-            new_min = np.min(raw_minus_10)
-            new_max = np.max(raw_minus_10)
-            new_stdv = np.std(raw_minus_10)
-
-            capture = [
-                i,
-                meta_i["run"],
-                meta_i["channel"],
-                meta_i["start_obs"],
-                meta_i["end_obs"],
-                meta_i["duration_obs"],
-            ]
-
-            if (
-                new_mean > filter_param[0]
-                and new_mean < filter_param[1]
-                and new_stdv < filter_param[2]
-                and new_med > filter_param[3]
-                and new_med < filter_param[4]
-                and new_min > filter_param[5]
-                and new_min < filter_param[6]
-                and new_max > filter_param[7]
-                and new_max < filter_param[8]
-            ):
-                meta_i["mean"] = new_mean
-                meta_i["median"] = new_med
-                meta_i["min"] = new_min
-                meta_i["max"] = new_max
-                capture.extend(
-                    [new_mean, new_stdv, new_med, new_min, new_max, meta_i["open_channel"]]
-                )
-
-                if classifier_name:
-                    # classifier uses obs 100-20100 of capture
-                    raw_100_to_20100 = raw_captures[i][100:20100]
-                    class_predict = classifier_predict(
-                        classifier, raw_100_to_20100, conf_thresh, classifier_name
-                    )
-                    if class_predict == -1:
-                        non_classified += 1
-                    else:
-                        captures[class_predict].append(capture)
-                else:
-                    captures[0].append(capture)
-            else:
-                non_filtered += 1
-
-        no_pass = float(non_filtered) / len(capture_meta_df.index) * 100
-        logger.info("Summary:")
-        logger.info("Did not pass filter: %0.3f %%" % no_pass)
-        if classifier_name:
-            semi_pass = float(non_classified) / len(capture_meta_df.index) * 100
-            logger.info("Passed filter but not classifier: %0.3f %%" % semi_pass)
-
-        # Save filtered captures. If classifier was enabled, each class is a
-        # different file.
-        for i, class_captures in enumerate(captures):
-            if class_captures:
-                filtered_captures = pd.DataFrame(class_captures)
-                filtered_captures.index = filtered_captures[0]
-                del filtered_captures[0]
-                filtered_captures.columns = capture_meta_df.columns
-
-                if "cnn" in classifier_name:
-                    filtered_fname = "%s_segmented_peptides_filtered%s_cnn_class%02d_%s.csv" % (
-                        date,
-                        filter_param[10],
-                        i,
-                        run,
-                    )
-                    filtered_fname = os.path.join(save_dir, filtered_fname)
-
-                elif "rf" in classifier_name:
-                    filtered_fname = "%s_segmented_peptides_filtered%s_rf_class%02d_%s.csv" % (
-                        date,
-                        filter_param[10],
-                        i,
-                        run,
-                    )
-                    filtered_fname = os.path.join(save_dir, filtered_fname)
-                else:
-                    filtered_fname = "%s_segmented_peptides_filtered%s_%s" % (
-                        date,
-                        filter_param[10],
-                        run,
-                    )
-                    filtered_fname = os.path.join(save_dir, filtered_fname)
-                logger.info("Saving to " + filtered_fname)
-                filtered_captures.to_csv(filtered_fname, sep="\t", index=True)
-                all_filtered_files.append(filtered_fname)
-
-        del captures
-        f5.close()
-        torch.cuda.empty_cache()
-
-    return all_filtered_files
+        lab = np.where(class_proba == max_proba)[0][0]
+        if class_labels is not None:
+            lab = class_labels[lab]
+        return lab, class_proba
+    else:
+        raise NotImplementedError(f"Classifier {classifier_name} not implemented.")
+
+
+def get_classification_for_read(f5, read_id, results_path):
+    results_path = f"{results_path}/{read_id}"
+    if results_path not in f5:
+        raise ValueError(
+            f"Read {read_id} has not been classified yet, or result "
+            f"is not stored at {results_path} in file {f5.filename}."
+        )
+    pred_class = f5[results_path].attrs["best_class"]
+    prob = f5[results_path].attrs["best_score"]
+    assigned_class = f5[results_path].attrs["assigned_class"]
+    passed_classification = bool(assigned_class == pred_class)
+    return pred_class, prob, assigned_class, passed_classification
+
+
+def write_classifier_details(f5, classifier_config, results_path):
+    """Write metadata about the classifier that doesn't need to be repeated for
+    each read.
+
+    Parameters
+    ----------
+    f5 : h5py.File
+        Opened fast5 file in a writable mode.
+    classifier_config : dict
+        Subset of the configuration parameters that belong to the classifier.
+    results_path : str
+        Where the classification results will be stored in the f5 file.
+    """
+    if results_path not in f5:
+        f5.create_group(results_path)
+    f5[results_path].attrs["model"] = classifier_config["classifier"]
+    f5[results_path].attrs["model_version"] = classifier_config["classifier_version"]
+    f5[results_path].attrs["model_file"] = classifier_config["classifier_path"]
+    f5[results_path].attrs["classification_threshold"] = classifier_config["min_confidence"]
+
+
+def write_classifier_result(f5, results_path, read_id, pred_class, prob, passed_classification):
+    results_path = f"{results_path}/{read_id}"
+    if results_path not in f5:
+        f5.create_group(results_path)
+    f5[results_path].attrs["best_class"] = pred_class
+    f5[results_path].attrs["best_score"] = prob
+    f5[results_path].attrs["assigned_class"] = pred_class if passed_classification else -1
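The classification results written above end up laid out in the fast5 file like this (paths and attribute names taken from the code; <read_id> is a placeholder):

    /Classification/<classifier_run_name>            attrs: model, model_version,
                                                     model_file, classification_threshold
    /Classification/<classifier_run_name>/<read_id>  attrs: best_class, best_score,
                                                     assigned_class (-1 when the read
                                                     fails the confidence threshold)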
diff --git a/poretitioner/utils/filter.py b/poretitioner/utils/filter.py
index 6f6a255..f256277 100644
--- a/poretitioner/utils/filter.py
+++ b/poretitioner/utils/filter.py
@@ -6,3 +6,202 @@
 # TODO : Write functionality for filtering nanopore captures : https://github.com/uwmisl/poretitioner/issues/43
 """
+import re
+
+import h5py
+import numpy as np
+
+from poretitioner import logger
+
+from . import raw_signal_utils
+
+
+def apply_feature_filters(signal, filters):
+    """
+    Check whether an array of current values (i.e. a single nanopore capture)
+    passes a set of filters. Filters are based on summary statistics
+    (e.g., mean) and a range of allowed values.
+
+    Notes on filter behavior: If the filters dict is empty, there are no filters
+    and the capture passes. Filters are inclusive of high and low values. Only
+    supported filters are allowed. (mean, stdv, median, min, max, length)
+
+    More complex filtering should be done with a custom function.
+
+    TODO : Move filtering to its own module : (somewhat related: https://github.com/uwmisl/poretitioner/issues/43)
+
+    Parameters
+    ----------
+    signal : array or list
+        Time series of nanopore current values for a single capture.
+    filters : dict
+        Keys are strings matching the supported filters, values are a tuple
+        giving the endpoints of the valid range. E.g. {"mean": (0.1, 0.5)}
+        defines a filter such that 0.1 <= mean(capture) <= 0.5.
+
+    Returns
+    -------
+    boolean
+        True if capture passes all filters; False otherwise.
+    """
+    local_logger = logger.getLogger()
+    # TODO: Implement logger best practices : https://github.com/uwmisl/poretitioner/issues/12
+    supported_filters = {
+        "mean": np.mean,
+        "stdv": np.std,
+        "median": np.median,
+        "min": np.min,
+        "max": np.max,
+        "length": len,
+    }
+    other_filters = ["ejected"]
+    pass_filters = True
+    for filt, filt_vals in filters.items():
+        if filt in supported_filters:
+            low, high = filt_vals
+            val = supported_filters[filt](signal)
+            if (low is not None and low > val) or (high is not None and val > high):
+                pass_filters = False
+                return pass_filters
+        elif filt in other_filters:
+            continue
+        else:
+            local_logger.warning(f"Filter {filt} not supported; ignoring.")
+    return pass_filters
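A quick usage sketch for apply_feature_filters (toy values; a (low, None) bound means "at least low, with no upper limit"):

    import numpy as np

    from poretitioner.utils.filter import apply_feature_filters

    signal = np.full(500, 0.2)  # toy capture: constant fractional current
    assert apply_feature_filters(signal, {"mean": (0.1, 0.5), "length": (100, None)})
    assert not apply_feature_filters(signal, {"max": (None, 0.1)})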
+ """ + for voltage_end in voltage_ends: + if np.abs(end_capture - voltage_end) < tol_obs: + return True + return False + + +def apply_filters_to_read(config, f5, read_id, filter_name): + passed_filters = True + + # Check whether the capture was ejected + if "ejected" in config["filters"][filter_name]: + only_use_ejected_captures = config["filters"][filter_name]["ejected"] # TODO + if only_use_ejected_captures: + capture_ejected = check_capture_ejection_by_read(f5, read_id) + if not capture_ejected: + passed_filters = False + return passed_filters + else: + only_use_ejected_captures = False # could skip this, leaving to help read logic + + # Apply all the filters + signal = raw_signal_utils.get_fractional_blockage_for_read(f5, read_id) + # print(config["filters"][filter_name]) + # print(f"min = {np.min(signal)}") + passed_filters = apply_feature_filters(signal, config["filters"][filter_name]) + return passed_filters + + +def filter_and_store_result(config, fast5_files, filter_name, overwrite=False): + # Apply a new set of filters + # Write filter results to fast5 file (using format) + # if only_use_ejected_captures = config["???"]["???"], then check_capture_ejection + # if there's a min length specified in the config, use that + # if feature filters are specified, apply those + # save all filter parameters in the filter_name path + filter_path = f"/Filter/{filter_name}" + + # TODO: parallelize this (embarassingly parallel structure) + for fast5_file in fast5_files: + with h5py.File(fast5_file, "a") as f5: + if overwrite is False and filter_path in f5: + continue + passed_read_ids = [] + for read_h5group_name in f5.get("/"): + if "read" not in read_h5group_name: + continue + read_id = re.findall(r"read_(.*)", read_h5group_name)[0] + + passed_filters = apply_filters_to_read(config, f5, read_id, filter_name) + if passed_filters: + passed_read_ids.append(read_id) + write_filter_results(f5, config, passed_read_ids, filter_name) + + +def write_filter_results(f5, config, read_ids, filter_name): + local_logger = logger.getLogger() + filter_path = f"/Filter/{filter_name}" + if filter_path not in f5: + f5.create_group(f"{filter_path}/pass") + filt_grp = f5.get(filter_path) + + # Save filter configuration to the fast5 file at filter_path + for k, v in config.items(): + if k == filter_name: + local_logger.debug("keys and vals:", k, v) + for filt, filt_vals in v.items(): + if len(filt_vals) == 2: + (min_filt, max_filt) = filt_vals + # Create compound dset for filters + local_logger.debug("filt types", type(min_filt), type(max_filt)) + dtypes = np.dtype([("min", type(min_filt), ("max", type(max_filt)))]) + d = filt_grp.create_dataset(k, (2,), dtype=dtypes) + d[filt] = (min_filt, max_filt) + else: + d = filt_grp.create_dataset(k) + d[filt] = filt_vals + + # For all read_ids that passed the filter (AKA reads that were passed in), + # create a hard link in the filter_path to the actual read's location in + # the fast5 file. 
+
+
+def write_filter_results(f5, config, read_ids, filter_name):
+    local_logger = logger.getLogger()
+    filter_path = f"/Filter/{filter_name}"
+    if filter_path not in f5:
+        f5.create_group(f"{filter_path}/pass")
+    filt_grp = f5.get(filter_path)
+
+    # Save filter configuration to the fast5 file at filter_path
+    for k, v in config.items():
+        if k == filter_name:
+            local_logger.debug(f"keys and vals: {k}, {v}")
+            for filt, filt_vals in v.items():
+                if len(filt_vals) == 2:
+                    (min_filt, max_filt) = filt_vals
+                    # Create compound dset for filters
+                    local_logger.debug(f"filt types: {type(min_filt)}, {type(max_filt)}")
+                    dtypes = np.dtype([("min", type(min_filt)), ("max", type(max_filt))])
+                    d = filt_grp.create_dataset(k, (2,), dtype=dtypes)
+                    d[filt] = (min_filt, max_filt)
+                else:
+                    d = filt_grp.create_dataset(k, data=filt_vals)
+
+    # For all read_ids that passed the filter (AKA reads that were passed in),
+    # create a hard link in the filter_path to the actual read's location in
+    # the fast5 file.
+    for read_id in read_ids:
+        read_path = f"/read_{read_id}"
+        read_grp = f5.get(read_path)
+        local_logger.debug(read_grp)
+        filter_read_path = f"{filter_path}/pass/read_{read_id}"
+        # Create a hard link from the filter read path to the actual read path
+        f5[filter_read_path] = read_grp
+
+
+def filter_like_existing(config, example_fast5, example_filter_path, fast5_files, new_filter_path):
+    # Filters a set of fast5 files exactly the same as an existing filter
+    # TODO : #68 : implement
+    raise NotImplementedError()
diff --git a/poretitioner/utils/model/NTERs_trained_cnn_05152019.pt b/poretitioner/utils/model/NTERs_trained_cnn_05152019.pt
deleted file mode 100644
index ffe2e0f..0000000
Binary files a/poretitioner/utils/model/NTERs_trained_cnn_05152019.pt and /dev/null differ
diff --git a/poretitioner/utils/model/NTERs_trained_cnn_05152019.statedict.pt b/poretitioner/utils/model/NTERs_trained_cnn_05152019.statedict.pt
index 36f9f48..6df041c 100644
Binary files a/poretitioner/utils/model/NTERs_trained_cnn_05152019.statedict.pt and b/poretitioner/utils/model/NTERs_trained_cnn_05152019.statedict.pt differ
diff --git a/poretitioner/utils/quantify.py b/poretitioner/utils/quantify.py
index 17387f0..f475ebd 100644
--- a/poretitioner/utils/quantify.py
+++ b/poretitioner/utils/quantify.py
@@ -7,7 +7,6 @@
 This module contains functionality for quantifying nanopore captures.
 """
-import logging
 import os
 import re

@@ -15,6 +14,8 @@
 import numpy as np
 import pandas as pd

+from poretitioner import logger
+
 from .raw_signal_utils import find_segments_below_threshold
 from .yaml_assistant import YAMLAssistant

@@ -25,7 +26,7 @@


 def get_related_files(input_file, raw_file_dir="", capture_file_dir=""):
-    """ TODO : Deprecate! : https://github.com/uwmisl/poretitioner/issues/40
+    """TODO : Deprecate! : https://github.com/uwmisl/poretitioner/issues/40

     Find files matching the input file in the data directory tree.

@@ -46,18 +47,10 @@
     Iterable of filenames (strings)
         The raw file (fast5) and capture file (unfiltered).
     """
-    logger = logging.getLogger("get_related_files")
-    if logger.handlers:
-        logger.handlers = []
-    logger.setLevel(logging.INFO)
-    logger.addHandler(logging.StreamHandler())
-
-    logger.debug(input_file)
-    logger.debug(raw_file_dir)
-    logger.debug(capture_file_dir)
+    local_logger = logger.getLogger()

     run_name = re.findall(r"(run\d\d_.*)\..*", input_file)[0]  # e.g. "run01_a"
-    logger.debug(run_name)
+    local_logger.debug(run_name)

     assert len(raw_file_dir) > 0
     raw_file = [x for x in os.listdir(raw_file_dir) if run_name in x][0]
@@ -74,14 +67,14 @@
         capture_file = input_file
         filtered_file = "Unspecified"
     else:
-        logger.error("Invalid file name")
+        local_logger.error("Invalid file name")
         return

-    logger.info("Filter File: " + filtered_file)
+    local_logger.info("Filter File: " + filtered_file)
     raw_file = os.path.join(raw_file_dir, raw_file)
-    logger.info("Raw File: " + raw_file)
+    local_logger.info("Raw File: " + raw_file)
     capture_file = os.path.join(capture_file_dir, capture_file)
-    logger.info("Capture File: " + capture_file)
+    local_logger.info("Capture File: " + capture_file)

     return raw_file, capture_file
@@ -132,37 +125,33 @@ def get_overlapping_regions(window, regions):


 def calc_time_until_capture(capture_windows, captures, blockages=None):
     """calc_time_until_capture
-    Finds all times between captures from a single channel. This is defined
-    as the open pore time from the end of the previous capture to the
-    current capture. Includes subtracting other non-capture blockages since
-    those blockages reduce the amount of overall open pore time.
-
-    Note: called by "get_capture_time" and "get_capture_time_tseg"
-
-    Parameters
-    ----------
-    capture_windows : list of tuples of ints [(start, end), ...]
-        Regions of current where the nanopore is available to accept a
-        capture. (I.e., is in a "normal" voltage state.) [(start, end), ...]
-    captures : list of tuples of ints [(start, end), ...]
-        Regions of current where a capture is residing in the pore. The
-        function is calculating time between these values (minus blockages).
-    blockages : list of tuples of ints [(start, end), ...]
-        Regions of current where the pore is blocked by any capture or non-
-        capture. These are removed from the time between captures, if
-        specified.
-
-    Returns
-    -------
-    list of floats
-        List of all capture times from a single channel.
-    """
+    Finds all times between captures from a single channel. This is defined
+    as the open pore time from the end of the previous capture to the
+    current capture. Includes subtracting other non-capture blockages since
+    those blockages reduce the amount of overall open pore time.
+
+    Note: called by "get_capture_time" and "get_capture_time_tseg"
+
+    Parameters
+    ----------
+    capture_windows : list of tuples of ints [(start, end), ...]
+        Regions of current where the nanopore is available to accept a
+        capture. (I.e., is in a "normal" voltage state.) [(start, end), ...]
+    captures : list of tuples of ints [(start, end), ...]
+        Regions of current where a capture is residing in the pore. The
+        function is calculating time between these values (minus blockages).
+    blockages : list of tuples of ints [(start, end), ...]
+        Regions of current where the pore is blocked by any capture or non-
+        capture. These are removed from the time between captures, if
+        specified.
+
+    Returns
+    -------
+    list of floats
+        List of all capture times from a single channel.
+    """
     # TODO: Implement logger best practices : https://github.com/uwmisl/poretitioner/issues/12
-    logger = logging.getLogger("calc_time_until_capture")
-    if logger.handlers:
-        logger.handlers = []
-    logger.setLevel(logging.INFO)
-    logger.addHandler(logging.StreamHandler())
+    # local_logger = logger.getLogger()

     all_capture_times = []

@@ -210,7 +199,11 @@


 def get_time_between_captures(
-    filtered_file, time_interval=None, raw_file_dir="", capture_file_dir="", config_file=""
+    filtered_file,
+    time_interval=None,
+    raw_file_dir="",
+    capture_file_dir="",
+    config_file="",
 ):
     """Get the average time between captures across all channels.
     Can be computed for the specified time interval, or across the entire run
     if not
@@ -235,11 +228,7 @@
         [description]
     """
     # TODO: Implement logger best practices : https://github.com/uwmisl/poretitioner/issues/12
-    logger = logging.getLogger("get_time_between_captures")
-    if logger.handlers:
-        logger.handlers = []
-    logger.setLevel(logging.INFO)
-    logger.addHandler(logging.StreamHandler())
+    local_logger = logger.getLogger("get_time_between_captures")

     # TODO : Implement capture fast5 I/O : https://github.com/uwmisl/poretitioner/issues/40

@@ -382,7 +371,7 @@
             captures_count.append(len(capture_times))
             checkpoint = end_voltage_seg
         else:
-            logger.warn(
+            local_logger.warning(
                 "No open voltage region in time segment ["
                 + str(checkpoint)
                 + ", "
             )
             timepoint_captures.append(-1)
         checkpoint = timepoint

-    logger.info("Number of Captures: " + str(captures_count))
+    local_logger.info("Number of Captures: " + str(captures_count))
     return timepoint_captures

@@ -403,15 +392,14 @@


 def get_capture_freq(
-    filtered_file, time_interval=None, raw_file_dir="", capture_file_dir="", config_file=""
+    filtered_file,
+    time_interval=None,
+    raw_file_dir="",
+    capture_file_dir="",
+    config_file="",
 ):
     # TODO: Implement logger best practices : https://github.com/uwmisl/poretitioner/issues/12
-    logger = logging.getLogger("get_capture_freq")
-    if logger.handlers:
-        logger.handlers = []
-    logger.setLevel(logging.INFO)
-    logger.addHandler(logging.StreamHandler())
-
+    local_logger = logger.getLogger()
     # TODO : Implement capture fast5 I/O : https://github.com/uwmisl/poretitioner/issues/40

     # Retrieve raw file and config file names
@@ -432,7 +420,7 @@
     good_channels = y.get_variable("fast5:good_channels:" + filtered_file[-11:-4])
     for i in range(0, len(good_channels)):
         good_channels[i] = "Channel_" + str(good_channels[i])
-    logger.info("Number of Channels: " + str(len(good_channels)))
+    local_logger.info("Number of Channels: " + str(len(good_channels)))

     # Process filtered captures file
     captures = pd.read_csv(filtered_file, index_col=0, header=0, sep="\t")
@@ -486,7 +474,7 @@
             all_capture_freq.append(np.mean(capture_counts) / (time_segments[0] / 600_000.0))
             checkpoint = end_voltage_seg
         else:
-            logger.warn(
+            local_logger.warning(
                 "No open voltage region in time segment ["
                 + str(checkpoint)
                 + ", "
diff --git a/poretitioner/utils/raw_signal_utils.py b/poretitioner/utils/raw_signal_utils.py
index f1f8fa1..f32a7db 100644
--- a/poretitioner/utils/raw_signal_utils.py
+++ b/poretitioner/utils/raw_signal_utils.py
@@ -94,15 +94,18 @@ def get_fractional_blockage(


 def get_local_fractional_blockage(
-    f5, open_channel_guess=220, open_channel_bound=15, channel=None, local_window_sz=1000
+    f5,
+    open_channel_guess=220,
+    open_channel_bound=15,
+    channel=None,
+    local_window_sz=1000,
 ):
     """Retrieve the scaled raw signal for the channel, compute the open pore
     current, and return the fractional blockage for that channel."""
     signal = get_scaled_raw_for_channel(f5, channel=channel)
     open_channel = find_open_channel_current(signal, open_channel_guess, bound=open_channel_bound)
     if open_channel is None:
-        print("open pore is None")
-
+        # print("open pore is None")
         return None

     frac = np.zeros(len(signal))
@@ -157,6 +160,110 @@ def get_sampling_rate(f5):
     return sample_rate


+def get_fractional_blockage_for_read(f5, read_id, start=None, end=None):
+    """Retrieve the fractional blockage (fractionalized current) for the
+    specified read.
+
+    Parameters
+    ----------
+    f5 : h5py.File
+        Fast5 file, open for reading using h5py.File.
+    read_id : str
+        read_id to retrieve fractionalized current.
+    start : int, optional
+        Index of the first observation to include, by default None (start of read).
+    end : int, optional
+        Index of the last observation to include, by default None (end of read).
+
+    Returns
+    -------
+    Numpy array (float)
+        Fractionalized current from the specified read_id.
+    """
+    signal = get_scaled_raw_for_read(f5, read_id, start=start, end=end)
+    if "read" in read_id:
+        channel_path = f"{read_id}/channel_id"
+    else:
+        channel_path = f"read_{read_id}/channel_id"
+    open_channel = f5.get(channel_path).attrs["open_channel_pA"]
+    frac = compute_fractional_blockage(signal, open_channel)
+    return frac
+
+
+def get_raw_signal_for_read(f5, read_id, start=None, end=None):
+    """Retrieve raw signal from open fast5 file for the specified read_id.
+
+    Parameters
+    ----------
+    f5 : h5py.File
+        Fast5 file, open for reading using h5py.File.
+    read_id : str
+        Read id to retrieve raw signal. Can be formatted as a path ("read_xxx...")
+        or just the read id ("xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx").
+
+    Returns
+    -------
+    Numpy array
+        Array representing sampled nanopore current.
+    """
+    if "read" in read_id:
+        signal_path = f"/{read_id}/Signal"
+    else:
+        signal_path = f"/read_{read_id}/Signal"
+    if signal_path in f5:
+        raw = f5.get(signal_path)[start:end]
+        return raw
+    else:
+        raise ValueError(f"Path {signal_path} not in fast5 file.")
+
+
+def get_scaled_raw_for_read(f5, read_id, start=None, end=None):
+    """Retrieve raw signal from open fast5 file, scaled to pA units.
+
+    Note: using UK sp. of digitization for consistency w/ file format
+
+    Parameters
+    ----------
+    f5 : h5py.File
+        Fast5 file, open for reading using h5py.File.
+    read_id : str
+        Read id to retrieve raw signal. Can be formatted as a path ("read_xxx...")
+        or just the read id ("xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx").
+
+    Returns
+    -------
+    Numpy array
+        Array representing sampled nanopore current, scaled to pA.
+    """
+    raw = get_raw_signal_for_read(f5, read_id, start=start, end=end)
+    offset, rng, digi = get_scale_metadata_for_read(f5, read_id)
+    return scale_raw_current(raw, offset, rng, digi)
+
+
+def get_scale_metadata_for_read(f5, read_id):
+    """Retrieve scaling values for a specific read in a segmented fast5 file.
+
+    Note: using UK sp. of digitization for consistency w/ file format
+
+    Parameters
+    ----------
+    f5 : h5py.File
+        Fast5 file, open for reading using h5py.File.
+    read_id : str
+        Read id to retrieve raw signal. Can be formatted as a path ("read_xxx...")
+        or just the read id ("xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx").
+
+    Returns
+    -------
+    Tuple
+        Offset, range, and digitisation values.
+    """
+    if "read" in read_id:
+        channel_path = f"{read_id}/channel_id"
+    else:
+        channel_path = f"read_{read_id}/channel_id"
+    attrs = f5.get(channel_path).attrs
+    offset = attrs.get("offset")
+    rng = attrs.get("range")
+    digi = attrs.get("digitisation")
+    return offset, rng, digi
+
+
 def get_raw_signal(f5, channel_no, start=None, end=None):
     """Retrieve raw signal from open fast5 file.
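For context, the offset/range/digitisation triple retrieved above is the standard ONT scaling metadata; scale_raw_current (already present in this module, body not shown in this diff) is expected to apply the usual ADC-to-picoamp conversion, sketched here under that assumption:

    import numpy as np

    def scale_raw_current_sketch(raw, offset, rng, digitisation):
        # ONT convention: current_pA = (adc_value + offset) * (range / digitisation)
        return (np.asarray(raw) + offset) * (rng / digitisation)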
""" -import logging import os import uuid @@ -15,9 +14,11 @@ import h5py import numpy as np from dask.diagnostics import ProgressBar + +from poretitioner import logger from poretitioner.application_info import get_application_info -from . import raw_signal_utils +from . import filter, raw_signal_utils ProgressBar().register() @@ -27,60 +28,6 @@ __name__ = app_info.name -def apply_capture_filters(capture, filters): - """ - Check whether a single nanopore capture passes a set of filters. Filters - are based on summary statistics (e.g., mean) and a range of allowed values. - - Notes on filter behavior: If the filters dict is empty, there are no filters - and the capture passes. Filters are inclusive of high and low values. Only - supported filters are allowed. (mean, stdv, median, min, max, length) - - More complex filtering should be done with a custom function. - - TODO : Move filtering to its own module : (somewhat related: https://github.com/uwmisl/poretitioner/issues/43) - - Parameters - ---------- - capture : array or list - Time series of nanopore current values for a single capture. - filters : dict - Keys are strings matching the supported filters, values are a tuple - giving the endpoints of the valid range. E.g. {"mean": (0.1, 0.5)} - defines a filter such that 0.1 <= mean(capture) <= 0.5. - - Returns - ------- - boolean - True if capture passes all filters; False otherwise. - """ - logger = logging.getLogger("apply_capture_filters") - # TODO: Implement logger best practices : https://github.com/uwmisl/poretitioner/issues/12 - if logger.handlers: - logger.handlers = [] - logger.setLevel(logging.INFO) - logger.addHandler(logging.StreamHandler()) - supported_filters = { - "mean": np.mean, - "stdv": np.std, - "median": np.median, - "min": np.min, - "max": np.max, - "length": len, - } - pass_filters = True - for filt, (low, high) in filters.items(): - if filt in supported_filters: - val = supported_filters[filt](capture) - if (low is not None and low > val) or (high is not None and val > high): - pass_filters = False - return pass_filters - else: - # Warn filter not supported - logger.warning(f"Filter {filt} not supported; ignoring.") - return pass_filters - - def find_captures( signal_pA, signal_threshold_frac, @@ -142,25 +89,30 @@ def find_captures( del signal_pA # Apply signal threshold & get list of captures captures = raw_signal_utils.find_segments_below_threshold(frac_current, signal_threshold_frac) - # If last_capture_only, reduce list of captures to only last - if terminal_capture_only and len(captures) > 1: - if np.abs(captures[-1][1] - len(frac_current)) <= end_tol: + annotated_captures = [] + for capture in captures: + ejected = np.abs(capture[1] - len(frac_current)) <= end_tol + annotated_captures.append((capture[0], capture[1], ejected)) + captures = annotated_captures[:] + # If terminal_capture_only, reduce list of captures to only last + if terminal_capture_only: + if len(captures) > 1 and captures[-1][2]: captures = [captures[-1]] else: captures = [] if delay > 0: for i, capture in enumerate(captures): - capture_start, capture_end = capture + capture_start, capture_end, ejected = capture if capture_end - capture_start > delay: capture_start += delay - captures[i] = (capture_start, capture_end) + captures[i] = (capture_start, capture_end, capture[2]) # Apply filters to remaining capture(s) filtered_captures = [] for capture in captures: - capture_start, capture_end = capture - if apply_capture_filters(frac_current[capture_start:capture_end], filters): + capture_start, 
+            capture_start, capture_end, ejected = capture
             if capture_end - capture_start > delay:
                 capture_start += delay
-                captures[i] = (capture_start, capture_end)
+                captures[i] = (capture_start, capture_end, ejected)
     # Apply filters to remaining capture(s)
     filtered_captures = []
     for capture in captures:
-        capture_start, capture_end = capture
-        if apply_capture_filters(frac_current[capture_start:capture_end], filters):
+        capture_start, capture_end, ejected = capture
+        if filter.apply_feature_filters(frac_current[capture_start:capture_end], filters):
             filtered_captures.append(capture)
     # Return list of captures
     return filtered_captures, open_channel_pA
@@ -295,17 +247,17 @@
     # config = {"param": "value",
     #           "filters": {"f1": (min, max), "f2: (min, max)"}}
     g = capture_f5.create_group("/Meta/Segmentation")
-    print(__name__)
+    # print(__name__)
     g.attrs.create("segmenter", __name__, dtype=f"S{len(__name__)}")
     g.attrs.create("segmenter_version", __version__, dtype=f"S{len(__version__)}")
     g_filt = capture_f5.create_group("/Meta/Segmentation/filters")
     for k, v in config.items():
-        if k == "filters":
+        if k == "base filter":
             for filt, (min_filt, max_filt) in v.items():
                 # Create compound dset for filters
                 dtypes = np.dtype([("min", type(min_filt), ("max", type(max_filt)))])
                 d = g_filt.create_dataset(k, (2,), dtype=dtypes)
-                d[...] = (min_filt, max_filt)
+                d[filt] = (min_filt, max_filt)
         else:
             g.create(k, v)

@@ -346,38 +298,36 @@
     about these segments (channel number, window endpoints, offset, range,
     digitisation).
     """
-    logger = logging.getLogger("_prep_capture_windows")
-    if logger.handlers:
-        logger.handlers = []
-    logger.setLevel(logging.DEBUG)
-    logger.addHandler(logging.StreamHandler())
+    local_logger = logger.getLogger()
     with h5py.File(bulk_f5_fname, "r") as bulk_f5:
-        logger.info(f"Reading in signals for bulk file: {bulk_f5_fname}")
+        local_logger.info(f"Reading in signals for bulk file: {bulk_f5_fname}")
         voltage = raw_signal_utils.get_voltage(
             bulk_f5, start=f5_subsection_start, end=f5_subsection_end
         )
-        logger.debug(f"voltage: {voltage}")
+        local_logger.debug(f"voltage: {voltage}")
         run_id = str(bulk_f5["/UniqueGlobalKey/tracking_id"].attrs.get("run_id"))[2:-1]
         sampling_rate = int(bulk_f5["/UniqueGlobalKey/context_tags"].attrs.get("sample_frequency"))

-        logger.info("Identifying capture windows (via voltage threshold).")
+        local_logger.info("Identifying capture windows (via voltage threshold).")
         capture_windows = raw_signal_utils.find_segments_below_threshold(voltage, voltage_t)
-        logger.debug(
+        local_logger.debug(
             f"run_id: {run_id}, sampling_rate: {sampling_rate}, "
             f"#/capture windows: {len(capture_windows)}"
         )
-        logger.debug("Prepping raw signals for parallel processing.")
+        local_logger.debug("Prepping raw signals for parallel processing.")
         raw_signals = []  # Input data to find_captures_dask_wrapper
         signal_metadata = []  # Metadata -- no need to pass through the segmenter
         for channel_no in good_channels:
             raw = raw_signal_utils.get_scaled_raw_for_channel(
-                bulk_f5, channel_no=channel_no, start=f5_subsection_start, end=f5_subsection_end
+                bulk_f5,
+                channel_no=channel_no,
+                start=f5_subsection_start,
+                end=f5_subsection_end,
             )
             offset, rng, digi = raw_signal_utils.get_scale_metadata(bulk_f5, channel_no)
             for capture_window in capture_windows:
-                raw_signals.append(
-                    (raw[capture_window[0] : capture_window[1]], signal_t, open_channel_prior_mean)
-                )
+                start, end = capture_window[0], capture_window[1]
+                raw_signals.append((raw[start:end], signal_t, open_channel_prior_mean))
                 signal_metadata.append([channel_no, capture_window, offset, rng, digi])
     return raw_signals, signal_metadata, run_id, sampling_rate
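parallel_find_captures below pulls its parameters from a nested config dict; the keys it actually reads are collected here (values illustrative only, and other sections of the config may carry additional keys):

    config = {
        "compute": {"n_workers": 4},
        "segment": {
            "good_channels": [1, 2, 3],
            "end_tol": 0,
            "terminal_capture_only": False,
        },
        "filters": {"base filter": {"length": (100, None)}},
        "output": {"capture_f5_dir": "out/", "captures_per_f5": 1000},
    }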
""" - logger = logging.getLogger("parallel_find_captures") - if logger.handlers: - logger.handlers = [] - logger.setLevel(logging.DEBUG) - logger.addHandler(logging.StreamHandler()) + local_logger = logger.getLogger() n_workers = config["compute"]["n_workers"] assert type(n_workers) is int @@ -432,10 +378,10 @@ def parallel_find_captures( good_channels = config["segment"]["good_channels"] end_tol = config["segment"]["end_tol"] terminal_capture_only = config["segment"]["terminal_capture_only"] - filters = config["filters"] + filters = config["filters"]["base filter"] save_location = config["output"][ "capture_f5_dir" - ] # TODO : Verify exists; don't create (handle earlier) + ] # TODO: Verify exists; don't create (handle earlier) n_per_file = config["output"]["captures_per_f5"] if f5_subsection_start is None: f5_subsection_start = 0 @@ -453,7 +399,7 @@ def parallel_find_captures( open_channel_prior_mean, ) - logger.debug("Loading up the bag with signals.") + local_logger.debug("Loading up the bag with signals.") bag = db.from_sequence(raw_signals, npartitions=64) capture_map = bag.map( find_captures_dask_wrapper, @@ -462,10 +408,10 @@ def parallel_find_captures( delay=delay, end_tol=end_tol, ) - logger.info("Beginning segmentation.") + local_logger.info("Beginning segmentation.") captures = capture_map.compute(num_workers=n_workers) assert len(captures) == len(context) - logger.debug(f"Captures (1st 10): {captures[:10]}") + local_logger.debug(f"Captures (1st 10): {captures[:10]}") # Write captures to fast5 n_in_file = 0 @@ -487,14 +433,16 @@ def parallel_find_captures( start_time_local + f5_subsection_start ) # relative to the start of the entire bulk f5 capture_duration = capture[1] - capture[0] - logger.debug(f"Capture duration: {capture_duration}") + ejected = capture[2] + local_logger.debug(f"Capture duration: {capture_duration}") if n_in_file >= n_per_file: n_in_file = 0 file_no += 1 capture_f5_fname = os.path.join(save_location, f"{run_id}_{file_no}.fast5") n_in_file += 1 - raw_pA = window_raw[capture[0] : capture[1]] - logger.debug(f"Length of raw signal : {len(raw_pA)}") + start, end = capture[0], capture[1] + raw_pA = window_raw[start:end] + local_logger.debug(f"Length of raw signal : {len(raw_pA)}") write_capture_to_fast5( capture_f5_fname, read_id, @@ -502,6 +450,7 @@ def parallel_find_captures( start_time_bulk, start_time_local, capture_duration, + ejected, voltage_t, open_channel_pA, channel_no, @@ -517,6 +466,7 @@ def parallel_find_captures( start_time_bulk, start_time_local, capture_duration, + ejected, voltage_t, open_channel_pA, channel_no, @@ -533,6 +483,7 @@ def write_capture_to_fast5( start_time_bulk, start_time_local, duration, + ejected, voltage, open_channel_pA, channel_no, @@ -560,6 +511,8 @@ def write_capture_to_fast5( region. (Relative to f5_subsection_start in parallel_find_captures().) duration : int Number of data points in the capture. + ejected : boolean + Whether or not the capture was ejected from the pore. voltage : int Voltage at which the capture occurs (single value for entire window). 
     open_channel_pA : float
@@ -586,6 +539,7 @@
         f5[signal_path].attrs["start_time_bulk"] = start_time_bulk
         f5[signal_path].attrs["start_time_local"] = start_time_local
         f5[signal_path].attrs["duration"] = duration
+        f5[signal_path].attrs["ejected"] = ejected
         f5[signal_path].attrs["voltage"] = voltage
         f5[signal_path].attrs["open_channel_pA"] = open_channel_pA
diff --git a/tests/data/classified_9captures.fast5 b/tests/data/classified_9captures.fast5
new file mode 100644
index 0000000..b60306c
Binary files /dev/null and b/tests/data/classified_9captures.fast5 differ
diff --git a/tests/data/classifier_details_test.fast5 b/tests/data/classifier_details_test.fast5
new file mode 100644
index 0000000..024d8a2
Binary files /dev/null and b/tests/data/classifier_details_test.fast5 differ
diff --git a/tests/data/filter_and_store_result_test.fast5 b/tests/data/filter_and_store_result_test.fast5
new file mode 100644
index 0000000..7f3d383
Binary files /dev/null and b/tests/data/filter_and_store_result_test.fast5 differ
diff --git a/tests/data/reads_fast5_dummy_9captures.fast5 b/tests/data/reads_fast5_dummy_9captures.fast5
new file mode 100644
index 0000000..6f748f8
Binary files /dev/null and b/tests/data/reads_fast5_dummy_9captures.fast5 differ
diff --git a/tests/test_classify.py b/tests/test_classify.py
new file mode 100644
index 0000000..63029f5
--- /dev/null
+++ b/tests/test_classify.py
@@ -0,0 +1,399 @@
+"""
+================
+test_classify.py
+================
+
+This module contains tests for classify.py functionality.
+
+"""
+import os
+import re
+from shutil import copyfile
+
+import h5py
+import pytest
+
+from poretitioner.utils import classify, raw_signal_utils
+from poretitioner.utils.NTERs_trained_cnn_05152019 import load_cnn
+
+
+def init_classifier_invalidinput_test():
+    clf_name = "invalid"
+    clf_path = "not_a_path"
+    with pytest.raises(ValueError):
+        classify.init_classifier(clf_name, clf_path)
+
+    clf_name = "NTER_cnn"
+    clf_path = "not_a_path"
+    with pytest.raises(OSError):
+        classify.init_classifier(clf_name, clf_path)
+
+    clf_name = "invalid"
+    clf_path = "../../poretitioner/model/NTERs_trained_cnn_05152019.statedict.pt"
+    with pytest.raises(ValueError):
+        classify.init_classifier(clf_name, clf_path)
+
+
+def load_cnn_test():
+    clf_path = "poretitioner/utils/model/NTERs_trained_cnn_05152019.statedict.pt"
+    load_cnn(clf_path)
+
+
+def init_classifier_cnn_test():
+    clf_name = "NTER_cnn"
+    clf_path = "poretitioner/utils/model/NTERs_trained_cnn_05152019.statedict.pt"
+    assert os.path.exists(clf_path)
+    classifier = classify.init_classifier(clf_name, clf_path)
+    assert classifier
+    assert classifier.conv1 is not None  # check existence of a layer
+
+
+def init_classifier_rf_test():
+    pass
+
+
+def predict_class_test():
+    # Predict classification result from 9 segment file
+    # Make assertions about the class & probability
+
+    # Load a classifier
+    clf_name = "NTER_cnn"
+    clf_path = "poretitioner/utils/model/NTERs_trained_cnn_05152019.statedict.pt"
+    classifier = classify.init_classifier(clf_name, clf_path)
+
+    correct_results = {
+        "697de4c1-1aef-41b9-ae0d-d676e983cb7e": {"label": 5, "prob": 0.726},
+        "8e8181d2-d749-4735-9cab-37648b463f88": {"label": 8, "prob": 0.808},
+        "97d06f2e-90e6-4ca7-91c4-084f13b693f2": {"label": 9, "prob": 0.977},
+        "a0c40f5a-c685-43b9-a3b7-ca13aa90d832": {"label": 9, "prob": 0.977},
+        "ab54dab5-26a7-4062-9d77-f63fc40f702c": {"label": 9, "prob": 0.979},
+        "c87905e6-fd62-4ac6-bcbd-c7f17ff4af14": {"label": 9, "prob": 0.977},
"cd6fa746-e93b-467f-a3fc-1c9af815f836": {"label": 6, "prob": 0.484}, + "df4365f4-bfe4-4d2c-8100-34c36cd11378": {"label": 9, "prob": 0.975}, + "f5d76520-c92b-4a9c-b5cb-a04414db527e": {"label": 6, "prob": 0.254}, + } + + # Load raw data & classify + f5_fname = "tests/data/reads_fast5_dummy_9captures.fast5" + with h5py.File(f5_fname, "r") as f5: + for grp in f5.get("/"): + if "read" in str(grp): + read_id = re.findall(r"read_(.*)", str(grp))[0] + + raw = raw_signal_utils.get_fractional_blockage_for_read(f5, read_id) + pred_label, pred_prob = classify.predict_class( + clf_name, classifier, raw[100:], class_labels=None + ) + # print( + # f'"{read_id}": {{"label": {pred_label}, "prob": {pred_prob:0.3f}}},' + # ) + p = pred_prob + # print(f"p: {p} {type(p)}") + + actual_label = correct_results[read_id]["label"] + actual_prob = correct_results[read_id]["prob"] + + assert actual_label == pred_label + assert abs(actual_prob - pred_prob) < 0.02 + + +def write_classifier_details_test(): + # Copy sample file + orig_f5_fname = "tests/data/reads_fast5_dummy_9captures.fast5" + test_f5_fname = "write_classifier_details_test.fast5" + copyfile(orig_f5_fname, test_f5_fname) + + # Define subset of the config dict that contains filter info + clf_config = { + "classifier": "NTER_cnn", + "classifier_path": "poretitioner/utils/model/NTERs_trained_cnn_05152019.statedict.pt", + "classifier_version": "1.0", + "start_obs": 100, + "end_obs": 2100, + "min_confidence": 0.9, + } + + # Call fn + with h5py.File(test_f5_fname, "r+") as f5: + classify.write_classifier_details(f5, clf_config, "/Classification/NTER_cnn") + + # Read back the expected details + with h5py.File(test_f5_fname, "r") as f5: + assert "/Classification" in f5 + assert "/Classification/NTER_cnn" in f5 + attrs = f5.get("/Classification/NTER_cnn").attrs + for key in ["model", "model_version", "model_file", "classification_threshold"]: + assert key in attrs + os.remove(test_f5_fname) + + +def write_classifier_result_test(): + # Copy sample file + orig_f5_fname = "tests/data/classifier_details_test.fast5" + test_f5_fname = "write_classifier_result_test.fast5" + copyfile(orig_f5_fname, test_f5_fname) + + classifier_run_name = "NTER_cnn" + results_path = f"/Classification/{classifier_run_name}" + + # Call fn + pred_class = 9 + prob = 0.977 + passed = True + with h5py.File(test_f5_fname, "r+") as f5: + classify.write_classifier_result( + f5, + results_path, + "c87905e6-fd62-4ac6-bcbd-c7f17ff4af14", + pred_class, + prob, + passed, + ) + + # Read back the expected info + with h5py.File(test_f5_fname, "r") as f5: + assert results_path in f5 + assert f"{results_path}/c87905e6-fd62-4ac6-bcbd-c7f17ff4af14" in f5 + attrs = f5.get(f"{results_path}/c87905e6-fd62-4ac6-bcbd-c7f17ff4af14").attrs + assert attrs["best_class"] == pred_class + assert attrs["best_score"] == prob + assert attrs["assigned_class"] == pred_class + os.remove(test_f5_fname) + + +def get_classification_for_read_test(): + f5_fname = "tests/data/classified_9captures.fast5" + read_id = "c87905e6-fd62-4ac6-bcbd-c7f17ff4af14" + classifier_run_name = "NTER_cnn" + results_path = f"/Classification/{classifier_run_name}" + with h5py.File(f5_fname, "r") as f5: + ( + pred_class, + prob, + assigned_class, + passed_classification, + ) = classify.get_classification_for_read(f5, read_id, results_path) + assert pred_class == 9 + assert prob == 0.9876 + assert assigned_class == 9 + assert passed_classification is True + + +def classify_fast5_file_unfiltered_test(): + # Predict classification result from 9 segment file 
+    # Make assertions about the class & probability
+
+    # Define subset of the config dict that contains config info
+    clf_config = {
+        "classifier": "NTER_cnn",
+        "classifier_path": "poretitioner/utils/model/NTERs_trained_cnn_05152019.statedict.pt",
+        "classifier_version": "1.0",
+        "start_obs": 100,
+        "end_obs": 21000,
+        "min_confidence": 0.9,
+    }
+
+    # Load a classifier
+    clf_name = "NTER_cnn"
+    clf_path = "poretitioner/utils/model/NTERs_trained_cnn_05152019.statedict.pt"
+    classifier = classify.init_classifier(clf_name, clf_path)
+    classifier_run_name = "NTER_cnn"
+    results_path = f"/Classification/{classifier_run_name}"
+
+    correct_results = {
+        "697de4c1-1aef-41b9-ae0d-d676e983cb7e": {"label": 5, "prob": 0.726},
+        "8e8181d2-d749-4735-9cab-37648b463f88": {"label": 8, "prob": 0.808},
+        "97d06f2e-90e6-4ca7-91c4-084f13b693f2": {"label": 9, "prob": 0.977},
+        "a0c40f5a-c685-43b9-a3b7-ca13aa90d832": {"label": 9, "prob": 0.977},
+        "ab54dab5-26a7-4062-9d77-f63fc40f702c": {"label": 9, "prob": 0.979},
+        "c87905e6-fd62-4ac6-bcbd-c7f17ff4af14": {"label": 9, "prob": 0.977},
+        "cd6fa746-e93b-467f-a3fc-1c9af815f836": {"label": 6, "prob": 0.484},
+        "df4365f4-bfe4-4d2c-8100-34c36cd11378": {"label": 9, "prob": 0.975},
+        "f5d76520-c92b-4a9c-b5cb-a04414db527e": {"label": 6, "prob": 0.254},
+    }
+
+    # Prepare file for testing
+    orig_f5_fname = "tests/data/reads_fast5_dummy_9captures.fast5"
+    test_f5_fname = "classify_fast5_file_unfiltered_test.fast5"
+    copyfile(orig_f5_fname, test_f5_fname)
+
+    # Classify f5 file directly
+    with h5py.File(test_f5_fname, "r+") as f5:
+        classify.classify_fast5_file(f5, clf_config, classifier, clf_name, "/", class_labels=None)
+
+    # Evaluate output written to file
+    with h5py.File(test_f5_fname, "r") as f5:
+        for grp in f5.get("/"):
+            if "read" in str(grp):
+                read_id = re.findall(r"read_(.*)", str(grp))[0]
+
+                (
+                    pred_label,
+                    pred_prob,
+                    assigned_class,
+                    passed_classification,
+                ) = classify.get_classification_for_read(f5, read_id, results_path)
+
+                actual_label = correct_results[read_id]["label"]
+                actual_prob = correct_results[read_id]["prob"]
+                t = clf_config["min_confidence"]
+
+                # print(
+                #     f"read_id: {read_id}\tactual_label: {actual_label}\tpred_label: {pred_label}"
+                # )
+
+                assert actual_label == pred_label
+                assert abs(actual_prob - pred_prob) < 0.02
+                assert passed_classification == bool(pred_prob > t)
+    os.remove(test_f5_fname)
+
+
+def classify_fast5_file_filtered_test():
+    # Predict classification result from 9 segment file
+    # Make assertions about the class & probability
+
+    # Define subset of the config dict that contains config info
+    clf_config = {
+        "classifier": "NTER_cnn",
+        "classifier_path": "poretitioner/utils/model/NTERs_trained_cnn_05152019.statedict.pt",
+        "classifier_version": "1.0",
+        "start_obs": 100,
+        "end_obs": 21000,
+        "min_confidence": 0.9,
+    }
+
+    # Load a classifier
+    clf_name = "NTER_cnn"
+    clf_path = "poretitioner/utils/model/NTERs_trained_cnn_05152019.statedict.pt"
+    classifier = classify.init_classifier(clf_name, clf_path)
+    classifier_run_name = "NTER_cnn"
+    results_path = f"/Classification/{classifier_run_name}"
+
+    correct_results = {
+        "697de4c1-1aef-41b9-ae0d-d676e983cb7e": {"label": 5, "prob": 0.726},
+        "8e8181d2-d749-4735-9cab-37648b463f88": {"label": 8, "prob": 0.808},
+        "97d06f2e-90e6-4ca7-91c4-084f13b693f2": {"label": 9, "prob": 0.977},
+        "a0c40f5a-c685-43b9-a3b7-ca13aa90d832": {"label": 9, "prob": 0.977},
+        "ab54dab5-26a7-4062-9d77-f63fc40f702c": {"label": 9, "prob": 0.979},
+        "c87905e6-fd62-4ac6-bcbd-c7f17ff4af14": {"label": 9, "prob": 0.977},
{"label": 9, "prob": 0.977}, + "cd6fa746-e93b-467f-a3fc-1c9af815f836": {"label": 6, "prob": 0.484}, + "df4365f4-bfe4-4d2c-8100-34c36cd11378": {"label": 9, "prob": 0.975}, + "f5d76520-c92b-4a9c-b5cb-a04414db527e": {"label": 6, "prob": 0.254}, + } + + # Prepare file for testing + orig_f5_fname = "tests/data/filter_and_store_result_test.fast5" + test_f5_fname = "classify_fast5_file_filtered_test.fast5" + copyfile(orig_f5_fname, test_f5_fname) + + # Use reads from filtered section, not root reads path + filter_name = "standard filter" + reads_path = f"/Filter/{filter_name}/pass" + + # Classify f5 file directly + with h5py.File(test_f5_fname, "r+") as f5: + classify.classify_fast5_file( + f5, clf_config, classifier, clf_name, reads_path, class_labels=None + ) + + # Evaluate output written to file + with h5py.File(test_f5_fname, "r") as f5: + for grp in f5.get(reads_path): + if "read" in str(grp): + read_id = re.findall(r"read_(.*)", str(grp))[0] + + ( + pred_label, + pred_prob, + assigned_class, + passed_classification, + ) = classify.get_classification_for_read(f5, read_id, results_path) + + actual_label = correct_results[read_id]["label"] + actual_prob = correct_results[read_id]["prob"] + t = clf_config["min_confidence"] + + # print( + # f"read_id: {read_id}\tactual_label: {actual_label}\tpred_label: {pred_label}" + # ) + + assert actual_label == pred_label + assert abs(actual_prob - pred_prob) < 0.02 + assert passed_classification == bool(pred_prob > t) + os.remove(test_f5_fname) + + +def filter_and_classify_test(): + # take config from test_filter.py and add to it + + orig_f5_fname = "tests/data/reads_fast5_dummy_9captures.fast5" + test_f5_fname = "filter_and_classify_test.fast5" + copyfile(orig_f5_fname, test_f5_fname) + + # Define config dict that contains filter info + config = { + "compute": {"n_workers": 4}, + "filters": { + "base filter": {"length": (100, None)}, + "test filter": {"min": (100, None)}, + }, + "output": {"capture_f5_dir": "tests/", "captures_per_f5": 1000}, + "classify": { + "classifier": "NTER_cnn", + "classifier_path": "poretitioner/utils/model/NTERs_trained_cnn_05152019.statedict.pt", + "classifier_version": "1.0", + "start_obs": 100, + "end_obs": 21000, + "min_confidence": 0.9, + }, + } + + # Predict classification result from 9 segment file + # Make assertions about the class & probability + + correct_results = { + "697de4c1-1aef-41b9-ae0d-d676e983cb7e": {"label": 5, "prob": 0.726}, + "8e8181d2-d749-4735-9cab-37648b463f88": {"label": 8, "prob": 0.808}, + "97d06f2e-90e6-4ca7-91c4-084f13b693f2": {"label": 9, "prob": 0.977}, + "a0c40f5a-c685-43b9-a3b7-ca13aa90d832": {"label": 9, "prob": 0.977}, + "ab54dab5-26a7-4062-9d77-f63fc40f702c": {"label": 9, "prob": 0.979}, + "c87905e6-fd62-4ac6-bcbd-c7f17ff4af14": {"label": 9, "prob": 0.977}, + "cd6fa746-e93b-467f-a3fc-1c9af815f836": {"label": 6, "prob": 0.484}, + "df4365f4-bfe4-4d2c-8100-34c36cd11378": {"label": 9, "prob": 0.975}, + "f5d76520-c92b-4a9c-b5cb-a04414db527e": {"label": 6, "prob": 0.254}, + } + + # Classify f5 file directly + filter_name = "base filter" + classify.filter_and_classify( + config, [test_f5_fname], overwrite=True, filter_name="base filter" + ) + + # Use reads from filtered section, not root reads path + reads_path = f"/Filter/{filter_name}/pass" + results_path = f"/Classification/{config['classify']['classifier']}" + + # Evaluate output written to file + with h5py.File(test_f5_fname, "r") as f5: + for grp in f5.get(reads_path): + if "read" in str(grp): + read_id = re.findall(r"read_(.*)", str(grp))[0] 
+
+                (
+                    pred_label,
+                    pred_prob,
+                    assigned_class,
+                    passed_classification,
+                ) = classify.get_classification_for_read(f5, read_id, results_path)
+
+                actual_label = correct_results[read_id]["label"]
+                actual_prob = correct_results[read_id]["prob"]
+                t = config["classify"]["min_confidence"]
+
+                # print(
+                #     f"read_id: {read_id}\tactual_label: {actual_label}\tpred_label: {pred_label}"
+                # )
+
+                assert actual_label == pred_label
+                assert abs(actual_prob - pred_prob) < 0.02
+                assert passed_classification == bool(pred_prob > t)
+    os.remove(test_f5_fname)
diff --git a/tests/test_filter.py b/tests/test_filter.py
new file mode 100644
index 0000000..b726375
--- /dev/null
+++ b/tests/test_filter.py
@@ -0,0 +1,237 @@
+"""
+================
+test_filter.py
+================
+
+This module contains tests for filter.py functionality.
+
+"""
+import os
+import re
+from shutil import copyfile
+
+import h5py
+import pytest
+
+import poretitioner.utils.filter as filter
+
+
+def apply_feature_filters_empty_test():
+    """Check for pass when no valid filters are provided."""
+    # capture -- mean: 1; stdv: 0; median: 1; min: 1; max: 1; len: 6
+    capture = [1, 1, 1, 1, 1, 1]
+    filters = {}
+    # No filter given -- pass
+    pass_filters = filter.apply_feature_filters(capture, filters)
+    assert pass_filters
+    filters = {"not_a_filter": (0, 1)}
+    # No *valid* filter given -- pass
+    pass_filters = filter.apply_feature_filters(capture, filters)
+    assert pass_filters
+
+
+def apply_feature_filters_length_test():
+    """Test length filter function."""
+    # capture -- mean: 1; stdv: 0; median: 1; min: 1; max: 1; len: 6
+    capture = [1, 1, 1, 1, 1, 1]
+
+    # Only length filter -- pass (edge case, inclusive high)
+    filters = {"length": (0, 6)}
+    pass_filters = filter.apply_feature_filters(capture, filters)
+    assert pass_filters
+
+    # Only length filter -- pass (edge case, inclusive low)
+    filters = {"length": (6, 10)}
+    pass_filters = filter.apply_feature_filters(capture, filters)
+    assert pass_filters
+
+    # Only length filter -- fail (too short)
+    filters = {"length": (8, 10)}
+    pass_filters = filter.apply_feature_filters(capture, filters)
+    assert not pass_filters
+
+    # Only length filter -- fail (too long)
+    filters = {"length": (0, 5)}
+    pass_filters = filter.apply_feature_filters(capture, filters)
+    assert not pass_filters
+
+    # Only length filter -- pass (no filter actually given)
+    filters = {"length": (None, None)}
+    pass_filters = filter.apply_feature_filters(capture, filters)
+    assert pass_filters
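+
+
+# NOTE (editorial sketch): the tests in this module assume that each filter is
+# a (low, high) range checked inclusively, that None means "unbounded" on that
+# side, and that unrecognized filter names are ignored. The helper below is a
+# minimal stand-in with those range semantics, for illustration only -- the
+# real implementation lives in poretitioner.utils.filter.
+def _example_range_check(value, bounds):
+    """Return True iff value lies within (low, high), inclusive on both ends.
+    A bound of None is treated as unbounded on that side."""
+    low, high = bounds
+    if low is not None and value < low:
+        return False
+    if high is not None and value > high:
+        return False
+    return True
+
+# e.g. the inclusive edge cases exercised above, for a capture of length 6:
+# _example_range_check(6, (0, 6))        # True  (inclusive high)
+# _example_range_check(6, (6, 10))       # True  (inclusive low)
+# _example_range_check(6, (8, 10))       # False (too short)
+# _example_range_check(6, (None, None))  # True  (unbounded)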
+
+
+def apply_feature_filters_mean_test():
+    """Test mean filter function. stdv, median, min, and max apply similarly."""
+    # capture -- mean: 0.5; stdv: 0.07; median: 0.5; min: 0.4; max: 0.6; len: 5
+    capture = [0.5, 0.5, 0.6, 0.4, 0.5]
+    # Only mean filter -- pass
+    filters = {"mean": (0, 1)}
+    pass_filters = filter.apply_feature_filters(capture, filters)
+    assert pass_filters
+
+    # Only mean filter -- fail (too high)
+    filters = {"mean": (0, 0.4)}
+    pass_filters = filter.apply_feature_filters(capture, filters)
+    assert not pass_filters
+
+    # Only mean filter -- fail (too low)
+    filters = {"mean": (0.6, 1)}
+    pass_filters = filter.apply_feature_filters(capture, filters)
+    assert not pass_filters
+
+
+def check_capture_ejection_by_read_test():
+    f5_fail = "tests/data/bulk_fast5_dummy.fast5"
+    assert os.path.exists(f5_fail)
+    bad_read_id = "akejwoeirjo;ewijr"
+    with h5py.File(f5_fail, "r") as f5:
+        with pytest.raises(ValueError) as e:
+            filter.check_capture_ejection_by_read(f5, bad_read_id)
+        assert "does not exist in the fast5 file" in str(e.value)
+    # TODO implement fast5 writing to file
+
+
+def check_capture_ejection_test():
+    """Essentially checks whether a value (end_capture) is close enough (within
+    a margin of tol_obs) to any value in voltage_ends.
+    """
+    end_capture = 1000
+    voltage_ends = [0, 1000, 2000, 3000]
+    tol_obs = 100
+    assert filter.check_capture_ejection(end_capture, voltage_ends, tol_obs=tol_obs)
+
+    end_capture = 1200
+    voltage_ends = [0, 1000, 2000, 3000]
+    tol_obs = 100
+    assert not filter.check_capture_ejection(end_capture, voltage_ends, tol_obs=tol_obs)
+
+    end_capture = 3100
+    voltage_ends = [0, 1000, 2000, 3000]
+    tol_obs = 100
+    assert not filter.check_capture_ejection(end_capture, voltage_ends, tol_obs=tol_obs)
+
+
+def apply_filters_to_read_test():
+    orig_f5_fname = "tests/data/reads_fast5_dummy_9captures.fast5"
+    filter_f5_fname = "apply_filters_to_read_test.fast5"
+    copyfile(orig_f5_fname, filter_f5_fname)
+
+    standard_filter = {
+        "mean": (0.05, 0.9),
+        "min": (0.001, 0.9),
+        "length": (None, 100_000),
+        "median": (0.05, 0.9),
+        "stdv": (0.01, 0.5),
+        "ejected": False,
+    }
+    config = {"filters": {"standard filter": standard_filter}}
+    pass_reads = [
+        "697de4c1-1aef-41b9-ae0d-d676e983cb7e",
+        "8e8181d2-d749-4735-9cab-37648b463f88",
+        "a0c40f5a-c685-43b9-a3b7-ca13aa90d832",
+        "cd6fa746-e93b-467f-a3fc-1c9af815f836",
+        "f5d76520-c92b-4a9c-b5cb-a04414db527e",
+    ]
+
+    filter_name = "standard filter"
+    with h5py.File(filter_f5_fname, "r") as f5:
+        for g in f5.get("/"):
+            if "read" in g:
+                read_id = re.findall(r"read_(.*)", str(g))[0]
+                passed = filter.apply_filters_to_read(config, f5, read_id, filter_name)
+                if passed:
+                    assert read_id in pass_reads
+                else:
+                    assert read_id not in pass_reads
+    os.remove(filter_f5_fname)
+
+
+def filter_and_store_result_test():
+    # TODO docstring
+    # Examine the 9 captures in the test file
+    # Create a set of filters that removes some and keeps others
+    # Call fn
+    # Verify which reads passed the filters
+
+    orig_f5_fname = "tests/data/reads_fast5_dummy_9captures.fast5"
+    filter_f5_fname = "filter_and_store_result_test.fast5"
+    copyfile(orig_f5_fname, filter_f5_fname)
+
+    standard_filter = {
+        "mean": (0.05, 0.9),
+        "min": (0.001, 0.9),
+        "length": (None, 100_000),
+        "median": (0.05, 0.9),
+        "stdv": (0.01, None),
+        "ejected": False,
+    }
+    lenient_filter = {"ejected": False}
+    strict_filter = {"mean": (1, 1), "min": (1, 1), "ejected": True}
+    config = {
+        "filters": {
+            "standard filter": standard_filter,
+            "lenient filter": lenient_filter,
+            "strict filter": strict_filter,
+        }
+    }
+
+    filter_names = ["standard 
filter", "lenient filter", "strict filter"] + n_passing_filters = [5, 9, 0] + for filter_name in filter_names: + filter.filter_and_store_result(config, [filter_f5_fname], filter_name, overwrite=True) + + with h5py.File(filter_f5_fname, "r") as f5: + for i, filter_name in enumerate(filter_names): + passing_reads = list(f5.get(f"/Filter/{filter_name}/pass")) + assert len(passing_reads) == n_passing_filters[i] + os.remove(filter_f5_fname) + + +def write_filter_results_test(): + # Copy a tester fast5 file + orig_f5_fname = "tests/data/reads_fast5_dummy_9captures.fast5" + filter_f5_fname = "write_filter_results_test.fast5" + copyfile(orig_f5_fname, filter_f5_fname) + # Define config dict that contains filter info + config = { + "compute": {"n_workers": 4}, + "segment": { + "voltage_threshold": -180, + "signal_threshold": 0.7, + "translocation_delay": 10, + "open_channel_prior_mean": 230, + "open_channel_prior_stdv": 25, + "good_channels": [1, 2, 3], + "end_tol": 0, + "terminal_capture_only": False, + }, + "filters": {"base filter": {"length": (100, None)}, "test filter": {"min": (100, None)}}, + "output": {"capture_f5_dir": "tests/", "captures_per_f5": 1000}, + } + + # Get read ids from the original file, and take some that "passed" our + # imaginary filter. + read_ids = [] + with h5py.File(filter_f5_fname, "r") as f5: + for g in f5.get("/"): + if "read" in str(g): + read_id = re.findall(r"read_(.*)", str(g))[0] + read_ids.append(read_id) + read_ids = read_ids[1 : len(read_ids) : 2] + + # Call fn + filter_name = "test filter" + with h5py.File(filter_f5_fname, "a") as f5: + filter.write_filter_results(f5, config, read_ids, filter_name) + + # Check: + # * len(read_ids) is correct + # * config values are all present (TODO) + with h5py.File(filter_f5_fname, "r") as f5: + g = list(f5.get(f"/Filter/{filter_name}/pass")) + g = [x for x in g if "read" in x] + assert len(g) == len(read_ids) + + os.remove(filter_f5_fname) + # TODO other versions: + # * read id is not present in the file (what behavior?) diff --git a/tests/test_raw_signal_utils.py b/tests/test_raw_signal_utils.py index eea762e..73b6d38 100644 --- a/tests/test_raw_signal_utils.py +++ b/tests/test_raw_signal_utils.py @@ -7,6 +7,7 @@ """ import h5py + import poretitioner.utils.raw_signal_utils as raw_signal_utils @@ -31,8 +32,7 @@ def get_voltage_segment_test(): def unscale_raw_current_test(): - """Test ability to convert back & forth between digital data & pA. - """ + """Test ability to convert back & forth between digital data & pA.""" bulk_f5_fname = "tests/data/bulk_fast5_dummy.fast5" channel_no = 1 with h5py.File(bulk_f5_fname, "r") as f5: diff --git a/tests/test_segment.py b/tests/test_segment.py index 1f881b2..3adc6b8 100644 --- a/tests/test_segment.py +++ b/tests/test_segment.py @@ -10,9 +10,10 @@ import h5py import numpy as np -import poretitioner.utils.segment as segment import pytest +import poretitioner.utils.segment as segment + def create_capture_fast5_test(): """Test valid capture fast5 produced. Conditions: @@ -64,8 +65,7 @@ def create_capture_fast5_overwrite_test(): def create_capture_fast5_bulk_exists_test(): - """Test error thrown when bulk fast5 does not exist. 
- """ + """Test error thrown when bulk fast5 does not exist.""" bulk_f5_fname = "tests/data/bulk_fast5_dummy_fake.fast5" capture_f5_fname = "tests/data/capture_fast5_dummy_bulkdemo.fast5" config = {} @@ -76,8 +76,7 @@ def create_capture_fast5_bulk_exists_test(): def create_capture_fast5_capture_path_missing_test(): - """Test error thrown when bulk fast5 does not exist. - """ + """Test error thrown when bulk fast5 does not exist.""" bulk_f5_fname = "tests/data/bulk_fast5_dummy_fake.fast5" capture_f5_fname = "tests/DNE/capture_fast5_dummy_bulkdemo.fast5" config = {} @@ -114,74 +113,10 @@ def create_capture_fast5_subrun_test(): os.remove(capture_f5_fname) -def apply_capture_filters_empty_test(): - """Check for pass when no valid filters are provided.""" - # capture -- mean: 1; stdv: 0; median: 1; min: 1; max: 1; len: 6 - capture = [1, 1, 1, 1, 1, 1] - filters = {} - # No filter given -- pass - pass_filters = segment.apply_capture_filters(capture, filters) - filters = {"not_a_filter": (0, 1)} - # No *valid* filter given -- pass - pass_filters = segment.apply_capture_filters(capture, filters) - assert pass_filters - - -def apply_capture_filters_length_test(): - """Test length filter function.""" - # capture -- mean: 1; stdv: 0; median: 1; min: 1; max: 1; len: 6 - capture = [1, 1, 1, 1, 1, 1] - - # Only length filter -- pass (edge case, inclusive high) - filters = {"length": (0, 6)} - pass_filters = segment.apply_capture_filters(capture, filters) - assert pass_filters - - # Only length filter -- pass (edge case, inclusive low) - filters = {"length": (6, 10)} - pass_filters = segment.apply_capture_filters(capture, filters) - assert pass_filters - - # Only length filter -- fail (too short) - filters = {"length": (8, 10)} - pass_filters = segment.apply_capture_filters(capture, filters) - assert not pass_filters - - # Only length filter -- fail (too long) - filters = {"length": (0, 5)} - pass_filters = segment.apply_capture_filters(capture, filters) - assert not pass_filters - - # Only length filter -- pass (no filter actually given) - filters = {"length": (None, None)} - pass_filters = segment.apply_capture_filters(capture, filters) - assert pass_filters - - -def apply_capture_filters_mean_test(): - """Test mean filter function. 
stdv, median, min, and max apply similarly.""" - # capture -- mean: 0.5; stdv: 0.07; median: 0.5; min: 0.4; max: 0.6; len: 5 - capture = [0.5, 0.5, 0.6, 0.4, 0.5] - # Only mean filter -- pass - filters = {"mean": (0, 1)} - pass_filters = segment.apply_capture_filters(capture, filters) - assert pass_filters - - # Only mean filter -- fail (too high) - filters = {"mean": (0, 0.4)} - pass_filters = segment.apply_capture_filters(capture, filters) - assert not pass_filters - - # Only mean filter -- fail (too low) - filters = {"mean": (0.6, 1)} - pass_filters = segment.apply_capture_filters(capture, filters) - assert not pass_filters - - def find_captures_0_single_capture_test(): data_file = "tests/data/capture_windows/test_data_capture_window_0.txt.gz" data = np.loadtxt(data_file, delimiter="\t", comments="#") - actual_captures = [(33822, 92691)] + actual_captures = [(33822, 92691, True)] signal_threshold_frac = 0.7 alt_open_channel_pA = 230 terminal_capture_only = False @@ -205,7 +140,7 @@ def find_captures_0_single_capture_test(): def find_captures_0_single_capture_terminal_test(): data_file = "tests/data/capture_windows/test_data_capture_window_0.txt.gz" data = np.loadtxt(data_file, delimiter="\t", comments="#") - actual_captures = [(33822, 92691)] + actual_captures = [(33822, 92691, True)] signal_threshold_frac = 0.7 alt_open_channel_pA = 230 terminal_capture_only = True @@ -492,7 +427,7 @@ def find_captures_8_capture_no_open_channel_test(): end_tol=end_tol, ) assert len(captures) == 2 - actual_captures = [(11310, 22098), (26617, 94048)] + actual_captures = [(11310, 22098, False), (26617, 94048, True)] for test_capture in captures: assert test_capture in actual_captures @@ -511,7 +446,7 @@ def parallel_find_captures_test(tmpdir): "end_tol": 0, "terminal_capture_only": False, }, - "filters": {"length": (100, None)}, + "filters": {"base filter": {"length": (100, None)}}, "output": {"capture_f5_dir": "tests/", "captures_per_f5": 1000}, } segment.parallel_find_captures(bulk_f5_fname, config) @@ -529,7 +464,7 @@ def parallel_find_captures_test(tmpdir): assert start_time_local == start_time_bulk # No offset here duration = a.get("duration") - print(d["Signal"]) + # print(d["Signal"]) len_signal = len(d["Signal"][()]) assert len_signal == duration @@ -551,6 +486,7 @@ def write_capture_to_fast5_test(tmpdir): duration = 8000 voltage = -180 open_channel_pA = 229.1 + ejected = True channel_no = 3 offset, rng, digi = -21.0, 3013.53, 8192.0 sampling_rate = 10000 @@ -561,6 +497,7 @@ def write_capture_to_fast5_test(tmpdir): start_time_bulk, start_time_local, duration, + ejected, voltage, open_channel_pA, channel_no,