NOTE(review): this patch was reconstructed from a copy whose line breaks and
indentation had been collapsed. Indentation and blank lines on context lines
are a best guess -- verify them against the target tree before applying. The
post-image blob ids on the "index" lines are carried over from the original
patch and no longer match the corrected content.

Fixes relative to the original patch:
  * sketchlib.py: the new helper is named get_database_statistics, matching
    the name that __main__.py imports and calls.
  * plot.py: plot_database_evaluations keeps its 'prefix' parameter, because
    the unchanged body still passes 'prefix = prefix' on to
    plot_evaluation_histogram.
  * __main__.py: the second call site keeps using the local 'output' rather
    than 'args.output', as the surrounding plot_scatter call does.
  * sketchlib.py: the HDF5 file is opened in a 'with' block so it is closed
    after the attributes have been read.

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 399be8e4..58cfa9fe 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -256,7 +256,7 @@ def main():
 
     # Imports are here because graph tool is very slow to load
     from .models import loadClusterFit, BGMMFit, DBSCANFit, RefineFit, LineageFit
-    from .sketchlib import checkSketchlibLibrary, removeFromDB
+    from .sketchlib import checkSketchlibLibrary, removeFromDB, get_database_statistics
     from .network import construct_network_from_edge_list
     from .network import construct_network_from_assignments
 
@@ -393,7 +393,8 @@ def main():
             plot_scatter(distMat,
                          args.output,
                          args.output + " distances")
-            plot_database_evaluations(args.output)
+            genome_lengths, ambiguous_bases = get_database_statistics(args.output)
+            plot_database_evaluations(args.output, genome_lengths, ambiguous_bases)
 
     #******************************#
     #*                            *#
@@ -471,7 +472,8 @@ def main():
             plot_scatter(distMat,
                          output,
                          output + " distances")
-            plot_database_evaluations(output)
+            genome_lengths, ambiguous_bases = get_database_statistics(output)
+            plot_database_evaluations(output, genome_lengths, ambiguous_bases)
 
         #******************************#
         #*                            *#
diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py
index abb8d441..042e548c 100644
--- a/PopPUNK/plot.py
+++ b/PopPUNK/plot.py
@@ -16,7 +16,6 @@
 # for other outputs
 import pandas as pd
 from pandas.errors import DataError
-import h5py
 from collections import defaultdict
 from sklearn import utils
 try:  # sklearn >= 0.22
@@ -82,21 +81,17 @@ def plot_scatter(X, out_prefix, title, kde = True):
     plt.savefig(os.path.join(out_prefix, os.path.basename(out_prefix) + '_distanceDistribution.png'))
     plt.close()
 
-def plot_database_evaluations(prefix):
+def plot_database_evaluations(prefix, genome_lengths, ambiguous_bases):
     """Plot histograms of sequence characteristics for database evaluation.
 
     Args:
         prefix (str)
             Prefix of database
+        genome_lengths (list)
+            Lengths of genomes in database
+        ambiguous_bases (list)
+            Counts of ambiguous bases in genomes in database
     """
-    db_file = prefix + "/" + os.path.basename(prefix) + ".h5"
-    ref_db = h5py.File(db_file, 'r')
-
-    genome_lengths = []
-    ambiguous_bases = []
-    for sample_name in list(ref_db['sketches'].keys()):
-        genome_lengths.append(ref_db['sketches/' + sample_name].attrs['length'])
-        ambiguous_bases.append(ref_db['sketches/' + sample_name].attrs['missing_bases'])
     plot_evaluation_histogram(genome_lengths,
                               n_bins = 100,
                               prefix = prefix,
diff --git a/PopPUNK/sketchlib.py b/PopPUNK/sketchlib.py
index c9629df5..7d1bb23a 100644
--- a/PopPUNK/sketchlib.py
+++ b/PopPUNK/sketchlib.py
@@ -659,3 +659,29 @@ def fitKmerCurve(pairwise, klist, jacobian):
 
     # Return core, accessory
     return(np.flipud(transformed_params))
+
+def get_database_statistics(prefix):
+    """Extract statistics for evaluating databases.
+
+    Args:
+        prefix (str)
+            Prefix of database
+
+    Returns:
+        genome_lengths (list)
+            Lengths of genomes in database
+        ambiguous_bases (list)
+            Counts of ambiguous bases in genomes in database
+    """
+    db_file = prefix + "/" + os.path.basename(prefix) + ".h5"
+
+    genome_lengths = []
+    ambiguous_bases = []
+    # Context manager closes the HDF5 file once the attributes are read
+    with h5py.File(db_file, 'r') as ref_db:
+        for sample_name in ref_db['sketches'].keys():
+            sample_attrs = ref_db['sketches/' + sample_name].attrs
+            genome_lengths.append(sample_attrs['length'])
+            ambiguous_bases.append(sample_attrs['missing_bases'])
+
+    return genome_lengths, ambiguous_bases