diff --git a/PopPUNK/assign.py b/PopPUNK/assign.py
index 0ea8ea57..a2fc348f 100644
--- a/PopPUNK/assign.py
+++ b/PopPUNK/assign.py
@@ -107,6 +107,9 @@ def get_options():
     queryingGroup.add_argument('--accessory', help='(with a \'refine\' or \'lineage\' model) '
                                                     'Use an accessory-distance only model for assigning queries '
                                                     '[default = False]', default=False, action='store_true')
+    queryingGroup.add_argument('--use-full-network', help='Use full network rather than reference network for querying [default = False]',
+                                                    default = False,
+                                                    action = 'store_true')
 
     # processing
     other = parser.add_argument_group('Other options')
@@ -235,7 +238,8 @@ def main():
                  args.gpu_dist,
                  args.gpu_graph,
                  args.deviceid,
-                 args.save_partial_query_graph)
+                 args.save_partial_query_graph,
+                 args.use_full_network)
 
     sys.stderr.write("\nDone\n")
 
@@ -268,7 +272,8 @@ def assign_query(dbFuncs,
                  gpu_dist,
                  gpu_graph,
                  deviceid,
-                 save_partial_query_graph):
+                 save_partial_query_graph,
+                 use_full_network):
     """Code for assign query mode for CLI"""
     createDatabaseDir = dbFuncs['createDatabaseDir']
     constructDatabase = dbFuncs['constructDatabase']
@@ -317,7 +322,8 @@ def assign_query(dbFuncs,
                     accessory,
                     gpu_dist,
                     gpu_graph,
-                    save_partial_query_graph)
+                    save_partial_query_graph,
+                    use_full_network)
     return(isolateClustering)
 
 def assign_query_hdf5(dbFuncs,
@@ -342,7 +348,8 @@ def assign_query_hdf5(dbFuncs,
                  accessory,
                  gpu_dist,
                  gpu_graph,
-                 save_partial_query_graph):
+                 save_partial_query_graph,
+                 use_full_network):
     """Code for assign query mode taking hdf5 as input. Written as
     a separate function so it can be called by web APIs"""
     # Modules imported here as graph tool is very slow to load (it pulls in all of GTK?)
@@ -360,6 +367,7 @@ def assign_query_hdf5(dbFuncs,
     from .network import get_vertex_list
     from .network import printExternalClusters
     from .network import vertex_betweenness
+    from .network import retain_only_query_clusters
     from .qc import sketchlibAssemblyQC
 
     from .plot import writeClusterCsv
@@ -454,7 +462,7 @@ def assign_query_hdf5(dbFuncs,
         ref_file_name = os.path.join(model_prefix,
                         os.path.basename(model_prefix) + file_extension_string + ".refs")
         use_ref_graph = \
-            os.path.isfile(ref_file_name) and not update_db and model.type != 'lineage'
+            os.path.isfile(ref_file_name) and not update_db and model.type != 'lineage' and not use_full_network
         if use_ref_graph:
             with open(ref_file_name) as refFile:
                 for reference in refFile:
@@ -792,12 +800,16 @@ def assign_query_hdf5(dbFuncs,
                                     output + "/" + os.path.basename(output) + db_suffix)
         else:
             storePickle(rNames, qNames, False, qrDistMat, dists_out)
-            if save_partial_query_graph and not serial:
-                if model.type == 'lineage':
+            if save_partial_query_graph:
+                genomeNetwork, pruned_isolate_lists = retain_only_query_clusters(genomeNetwork, rNames, qNames, use_gpu = gpu_graph)
+                if model.type == 'lineage' and not serial:
                     save_network(genomeNetwork[min(model.ranks)], prefix = output, suffix = '_graph', use_gpu = gpu_graph)
                 else:
                     graph_suffix = file_extension_string + '_graph'
                     save_network(genomeNetwork, prefix = output, suffix = graph_suffix, use_gpu = gpu_graph)
+                with open(f"{output}/{os.path.basename(output)}_query.subset",'w') as pruned_isolate_csv:
+                    for isolate in pruned_isolate_lists:
+                        pruned_isolate_csv.write(isolate + '\n')
 
     return(isolateClustering)
 
diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index 2568e934..3f347bde 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -1920,28 +1920,8 @@ def prune_graph(prefix, reflist, samples_to_keep, output_db_name, threads, use_g
         if os.path.exists(network_fn):
             network_found = True
             sys.stderr.write("Loading network from " + network_fn + "\n")
-            samples_to_keep_set = frozenset(samples_to_keep)
             G = load_network_file(network_fn, use_gpu = use_gpu)
-            if use_gpu:
-                # Identify indices
-                reference_indices = [i for (i,name) in enumerate(reflist) if name in samples_to_keep_set]
-                # Generate data frame
-                G_df = G.view_edge_list()
-                if 'src' in G_df.columns:
-                    G_df.rename(columns={'src': 'source','dst': 'destination'}, inplace=True)
-                # Filter data frame
-                G_new_df = G_df[G_df['source'].isin(reference_indices) & G_df['destination'].isin(reference_indices)]
-                # Translate network indices to match name order
-                G_new = translate_network_indices(G_new_df, reference_indices)
-            else:
-                reference_vertex = G.new_vertex_property('bool')
-                for n, vertex in enumerate(G.vertices()):
-                    if reflist[n] in samples_to_keep_set:
-                        reference_vertex[vertex] = True
-                    else:
-                        reference_vertex[vertex] = False
-                G_new = gt.GraphView(G, vfilt = reference_vertex)
-                G_new = gt.Graph(G_new, prune = True)
+            G_new = remove_nodes_from_graph(G, reflist, samples_to_keep, use_gpu)
             save_network(G_new,
                          prefix = output_db_name,
                          suffix = '_graph',
@@ -1949,3 +1929,91 @@ def prune_graph(prefix, reflist, samples_to_keep, output_db_name, threads, use_g
                          use_gpu = use_gpu)
     if not network_found:
         sys.stderr.write('No network file found for pruning\n')
+
+def remove_nodes_from_graph(G,reflist, samples_to_keep, use_gpu):
+    """Return a modified graph containing only the requested nodes
+
+    Args:
+        reflist (list)
+            Ordered list of sequences of database
+        samples_to_keep (list)
+            The names of samples to be retained in the graph
+        use_gpu (bool)
+            Whether graph is a cugraph or not
+            [default = False]
+
+    Returns:
+        G_new (graph)
+            Pruned graph
+    """
+    samples_to_keep_set = frozenset(samples_to_keep)
+    if use_gpu:
+        # Identify indices
+        reference_indices = [i for (i,name) in enumerate(reflist) if name in samples_to_keep_set]
+        # Generate data frame
+        G_df = G.view_edge_list()
+        if 'src' in G_df.columns:
+            G_df.rename(columns={'src': 'source','dst': 'destination'}, inplace=True)
+        # Filter data frame
+        G_new_df = G_df[G_df['source'].isin(reference_indices) & G_df['destination'].isin(reference_indices)]
+        # Translate network indices to match name order
+        G_new = translate_network_indices(G_new_df, reference_indices)
+    else:
+        reference_vertex = G.new_vertex_property('bool')
+        for n, vertex in enumerate(G.vertices()):
+            if reflist[n] in samples_to_keep_set:
+                reference_vertex[vertex] = True
+            else:
+                reference_vertex[vertex] = False
+        G_new = gt.GraphView(G, vfilt = reference_vertex)
+        G_new = gt.Graph(G_new, prune = True)
+    return G_new
+
+def retain_only_query_clusters(G, rlist, qlist, use_gpu = False):
+    """
+    Removes all components that do not contain a query sequence.
+
+    Args:
+        G (graph)
+            Network of queries linked to reference sequences
+        rlist (list)
+            List of reference sequence labels
+        qlist (list)
+            List of query sequence labels
+        use_gpu (bool)
+            Whether to use GPUs for network construction
+
+    Returns:
+        G (graph)
+            The resulting network
+        pruned_names (list)
+            The labels of the sequences in the pruned network
+    """
+    num_refs = len(rlist)
+    components_with_query = []
+    combined_names = rlist + qlist
+    pruned_names = []
+    if use_gpu:
+        sys.stderr.write('Not compatible with GPU networks yet\n')
+        query_subgraph = G
+    else:
+        components = gt.label_components(G)[0].a
+        for component in components:
+            subgraph = gt.GraphView(G, vfilt=components == component)
+            max_node = max([int(v) for v in subgraph.vertices()])
+            if max_node >= num_refs:
+                components_with_query.append(int(component))
+        # Create a boolean filter based on the list of component IDs
+        query_filter = G.new_vertex_property("bool")
+        for v in G.vertices():
+            query_filter[int(v)] = (components[int(v)] in components_with_query)
+            if query_filter[int(v)]:
+                pruned_names.append(combined_names[int(v)])
+
+        # Create a filtered graph with only the specified components
+        query_subgraph = gt.GraphView(G, vfilt=query_filter)
+
+        # Purge the filtered graph to remove the other components permanently
+        query_subgraph.purge_vertices()
+
+    return query_subgraph, pruned_names
diff --git a/PopPUNK/visualise.py b/PopPUNK/visualise.py
index 40ce0fff..1427cba4 100644
--- a/PopPUNK/visualise.py
+++ b/PopPUNK/visualise.py
@@ -90,6 +90,9 @@ def get_options():
     iGroup.add_argument('--display-cluster',
                         help='Column of clustering CSV to use for plotting',
                         default=None)
+    iGroup.add_argument('--use-partial-query-graph',
+                        help='File listing sequences in partial query graph after assignment',
+                        default=None)
 
     # output options
     oGroup = parser.add_argument_group('Output options')
@@ -190,6 +193,7 @@ def generate_visualisations(query_db,
                             mst_distances,
                             overwrite,
                             display_cluster,
+                            use_partial_query_graph,
                             tmp):
 
     from .models import loadClusterFit
@@ -200,6 +204,7 @@ def generate_visualisations(query_db,
     from .network import cugraph_to_graph_tool
     from .network import save_network
     from .network import sparse_mat_to_network
+    from .network import remove_nodes_from_graph
 
     from .plot import drawMST
     from .plot import outputsForMicroreact
@@ -353,9 +358,10 @@ def generate_visualisations(query_db,
 
     # extract subset of distances if requested
     all_seq = combined_seq
-    if include_files is not None:
+    if include_files is not None or use_partial_query_graph is not None:
         viz_subset = set()
-        with open(include_files, 'r') as assemblyFiles:
+        subset_file = include_files if include_files is not None else use_partial_query_graph
+        with open(subset_file, 'r') as assemblyFiles:
             for assembly in assemblyFiles:
                 viz_subset.add(assembly.rstrip())
         if len(viz_subset.difference(combined_seq)) > 0:
@@ -605,20 +611,20 @@ def generate_visualisations(query_db,
             if gpu_graph:
                 genomeNetwork = cugraph_to_graph_tool(genomeNetwork, isolateNameToLabel(all_seq))
             # Hard delete from network to remove samples (mask doesn't work neatly)
-            if viz_subset is not None:
-                remove_list = []
-                for keep, idx in enumerate(row_slice):
-                    if not keep:
-                        remove_list.append(idx)
-                genomeNetwork.remove_vertex(remove_list)
+            if include_files is not None:
+                genomeNetwork = remove_nodes_from_graph(genomeNetwork, all_seq, viz_subset, use_gpu = gpu_graph)
         elif rank_fit is not None:
             genomeNetwork = sparse_mat_to_network(sparse_mat, combined_seq, use_gpu = gpu_graph)
         else:
             sys.stderr.write('Cytoscape output requires a network file or lineage rank fit to be provided\n')
             sys.exit(1)
+        # If network has been pruned then only use the appropriate subset of names - otherwise use all names
+        # for full network
+        node_labels = viz_subset if (use_partial_query_graph is not None or include_files is not None) \
+                        else combined_seq
         outputsForCytoscape(genomeNetwork,
                             mst_graph,
-                            combined_seq,
+                            node_labels,
                             isolateClustering,
                             output,
                             info_csv)
@@ -663,6 +669,7 @@ def main():
                             args.mst_distances,
                             args.overwrite,
                             args.display_cluster,
+                            args.use_partial_query_graph,
                             args.tmp)
 
 if __name__ == '__main__':
diff --git a/test/run_test.py b/test/run_test.py
index fe496b74..a1da15cc 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -130,7 +130,7 @@
 subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model dbscan --ref-db batch12 --output batch12 --overwrite", shell=True, check=True)
 subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model refine --ref-db batch12 --output batch12 --overwrite", shell=True, check=True)
 subprocess.run(python_cmd + " ../poppunk_assign-runner.py --db batch12 --query rfile3.txt --output batch3 --external-clustering batch12_external_clusters.csv --save-partial-query-graph --overwrite", shell=True, check=True)
-subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db batch12 --query-db batch3 --output batch123_viz --external-clustering batch12_external_clusters.csv --previous-query-clustering batch3/batch3_external_clusters.csv --cytoscape --rapidnj rapidnj --network-file ./batch12/batch12_graph.gt --overwrite", shell=True, check=True)
+subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db batch12 --query-db batch3 --output batch123_viz --external-clustering batch12_external_clusters.csv --previous-query-clustering batch3/batch3_external_clusters.csv --cytoscape --rapidnj rapidnj --network-file ./batch3/batch3_graph.gt --use-partial-query-graph ./batch3/batch3_query.subset --overwrite", shell=True, check=True)
 
 # citations
 sys.stderr.write("Printing citations\n")