Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small fixes to database pruning and updating #328

Merged
merged 70 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from 55 commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
357ae4c
Move pairwise distance plot to db dir
nickjcroucher Jan 30, 2023
d384335
Add db evaluation histograms
nickjcroucher Jan 30, 2023
17cb947
Correct variable name
nickjcroucher Jan 30, 2023
325171c
Inform user of QC process and output
nickjcroucher Feb 1, 2023
316283a
Fix bracket
nickjcroucher Feb 1, 2023
7041cf7
Fix plot names
nickjcroucher Feb 1, 2023
185c939
Fix plot prefixes
nickjcroucher Feb 1, 2023
bda4398
Validated update of https://github.com/rapidsai/cugraph/pull/2671
nickjcroucher Feb 1, 2023
d25cd5b
Enable subsampling for graph analysis
nickjcroucher Feb 1, 2023
817f7ab
Update triangle counting
nickjcroucher Feb 1, 2023
3e58386
Fix argument parsing
nickjcroucher Feb 1, 2023
17e3af7
Use subgraph for statistics
nickjcroucher Feb 1, 2023
ae49583
Update network count
nickjcroucher Feb 1, 2023
934f8f5
Remove correction factor
nickjcroucher Feb 1, 2023
5c29da1
Fix subsampling for graph-tool
nickjcroucher Feb 1, 2023
16c274a
Add tests for subsampling
nickjcroucher Feb 1, 2023
75ef0d5
Enable sampling for multirefine
nickjcroucher Feb 1, 2023
bde0a3e
Enable extraction of full graph statistics
nickjcroucher Feb 1, 2023
757eeff
Update function docstring
nickjcroucher Feb 2, 2023
7311f71
Allow for tmp file space
nickjcroucher Feb 3, 2023
9e01069
Reduce precision and size of distance matrix
nickjcroucher Feb 3, 2023
c36489a
Make indentation consistent
nickjcroucher Feb 3, 2023
6288446
Prune rank list for rare strains
nickjcroucher Feb 3, 2023
98e36a8
Remove unnecessary file generation
nickjcroucher Feb 3, 2023
9a19cdc
Fix pandas error handling
nickjcroucher Feb 4, 2023
df9117f
Filter duplicate rows in epi csv
nickjcroucher Feb 5, 2023
a4ab201
Remove obsolete variable
nickjcroucher Feb 5, 2023
aa637d7
Harmonise behaviour of BGMM and DBSCAN
nickjcroucher Feb 7, 2023
3306f01
Enable alteration of unconstrained boundary search
nickjcroucher Feb 9, 2023
bf4fcec
Correct sign
nickjcroucher Feb 9, 2023
5652532
Merge pull request #288 from bacpop/master
nickjcroucher Nov 14, 2023
025a818
Resolve conflicts with master
nickjcroucher Feb 29, 2024
f5ec728
Update function call
nickjcroucher Feb 29, 2024
e8cb660
Change network parsing
nickjcroucher Mar 11, 2024
c4ba9d5
Update column names
nickjcroucher Mar 12, 2024
3c4a9cb
Rename columns
nickjcroucher Mar 12, 2024
ec4f8c1
Fix component graph construction
nickjcroucher Mar 13, 2024
986a43d
Fix BFS arg
nickjcroucher Mar 13, 2024
1319225
Updates from master branch
nickjcroucher May 24, 2024
5e63f2e
Enable compatibility with CPU-only mandrake
nickjcroucher May 24, 2024
137895f
Fix processing of sample removal file
nickjcroucher Sep 19, 2024
99a1e83
Add test for pruning database
nickjcroucher Sep 19, 2024
2b6d085
Pass arguments needed for network pruning
nickjcroucher Sep 19, 2024
868fdfc
Add and test prune_database function
nickjcroucher Sep 19, 2024
6e54034
Fix file for graph pruning
nickjcroucher Sep 19, 2024
3cc0aef
Resolve conflicts with master
nickjcroucher Sep 20, 2024
f195985
Resolve conflicts with master
nickjcroucher Sep 20, 2024
95e12c1
Resolve conflicts with master
nickjcroucher Sep 20, 2024
abd4cb6
Add GPU compatibility
nickjcroucher Sep 20, 2024
c97c085
Generalise function to any network
nickjcroucher Sep 20, 2024
9453711
Add missing colon
nickjcroucher Sep 20, 2024
a71505f
Merge branch 'master' into db_pruning_fix
nickjcroucher Sep 20, 2024
bea29b2
Bump version
nickjcroucher Sep 20, 2024
e811b7b
Consistent data structures for updated cython compatibility
nickjcroucher Sep 20, 2024
962b198
Remove assumption of reference database when updating
nickjcroucher Sep 20, 2024
957d4fe
Document network sampling arguments
nickjcroucher Oct 8, 2024
6bbfdb7
Tidy up between strain cluster identification
nickjcroucher Oct 8, 2024
a1691af
Remove obsolete line
nickjcroucher Oct 8, 2024
7d9e17d
Move database processing function into sketchlib.py
nickjcroucher Oct 8, 2024
77537af
Set tmp as default
nickjcroucher Oct 8, 2024
2054e6c
Define adj
nickjcroucher Oct 8, 2024
4d8f1b2
Remove incorrect return description
nickjcroucher Oct 8, 2024
f166cfa
Fix selection of between strain cluster
nickjcroucher Oct 8, 2024
a8b58c0
Improve variable name
nickjcroucher Oct 8, 2024
96e5b5d
Correct function name
nickjcroucher Oct 8, 2024
769834b
Pass prefix to plotting function
nickjcroucher Oct 8, 2024
e56a3e6
Pass correct arguments to plotting function
nickjcroucher Oct 8, 2024
2d311bf
Fix output name processing
nickjcroucher Oct 8, 2024
fce1a21
Fix output name processing again
nickjcroucher Oct 8, 2024
5b4eb67
Update argument name
nickjcroucher Oct 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion PopPUNK/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

'''PopPUNK (POPulation Partitioning Using Nucleotide Kmers)'''

__version__ = '2.7.1'
__version__ = '2.7.2'

# Minimum sketchlib version
SKETCHLIB_MAJOR = 2
Expand Down
33 changes: 25 additions & 8 deletions PopPUNK/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def get_options():
type=float)

# model refinement
refinementGroup = parser.add_argument_group('Refine model options')
refinementGroup = parser.add_argument_group('Network analysis and model refinement options')
refinementGroup.add_argument('--pos-shift', help='Maximum amount to move the boundary right past between-strain mean',
type=float, default = 0)
refinementGroup.add_argument('--neg-shift', help='Maximum amount to move the boundary left past within-strain mean]',
Expand All @@ -156,6 +156,9 @@ def get_options():
refinementGroup.add_argument('--score-idx',
help='Index of score to use [default = 0]',
type=int, default = 0, choices=[0, 1, 2])
refinementGroup.add_argument('--summary-sample',
help='Number of sequences used to estimate graph properties [default = all]',
type=int, default = None)
nickjcroucher marked this conversation as resolved.
Show resolved Hide resolved
refinementGroup.add_argument('--betweenness-sample',
help='Number of sequences used to estimate betweeness with a GPU [default = 100]',
type = int, default = betweenness_sample_default)
Expand Down Expand Up @@ -264,6 +267,7 @@ def main():

from .plot import writeClusterCsv
from .plot import plot_scatter
from .plot import plot_database_evaluations

from .qc import prune_distance_matrix, qcDistMat, sketchlibAssemblyQC, remove_qc_fail

Expand Down Expand Up @@ -387,8 +391,9 @@ def main():
# Plot results
if not args.no_plot:
plot_scatter(distMat,
f"{args.output}/{os.path.basename(args.output)}_distanceDistribution",
args.output,
args.output + " distances")
plot_database_evaluations(args.output)

#******************************#
#* *#
Expand Down Expand Up @@ -424,18 +429,18 @@ def main():
if args.remove_samples:
with open(args.remove_samples, 'r') as f:
for line in f:
fail_unconditionally[line.rstrip] = ["removed"]
sample_to_remove = line.rstrip()
if sample_to_remove in refList:
fail_unconditionally[sample_to_remove] = ["removed"]

# assembly qc
sys.stderr.write("Running sequence QC\n")
nickjcroucher marked this conversation as resolved.
Show resolved Hide resolved
pass_assembly_qc, fail_assembly_qc = \
sketchlibAssemblyQC(args.ref_db,
refList,
qc_dict)
sys.stderr.write(f"{len(fail_assembly_qc)} samples failed\n")

# QC pairwise distances to identify long distances indicative of anomalous sequences in the collection
sys.stderr.write("Running distance QC\n")
pass_dist_qc, fail_dist_qc = \
qcDistMat(distMat,
refList,
Expand All @@ -446,19 +451,27 @@ def main():

# Get list of passing samples
pass_list = set(refList) - fail_unconditionally.keys() - fail_assembly_qc.keys() - fail_dist_qc.keys()
assert(pass_list == set(refList).intersection(set(pass_assembly_qc)).intersection(set(pass_dist_qc)))
assert(pass_list == (set(refList) - fail_unconditionally.keys()).intersection(set(pass_assembly_qc)).intersection(set(pass_dist_qc)))
passed = [x for x in refList if x in pass_list]
if qc_dict['type_isolate'] is not None and qc_dict['type_isolate'] not in pass_list:
raise RuntimeError('Type isolate ' + qc_dict['type_isolate'] + \
' not found in isolates after QC; check '
'name of type isolate and QC options\n')


sys.stderr.write(f"{len(passed)} samples passed QC\n")
if len(passed) < len(refList):
remove_qc_fail(qc_dict, refList, passed,
[fail_unconditionally, fail_assembly_qc, fail_dist_qc],
args.ref_db, distMat, output,
args.strand_preserved, args.threads)
args.strand_preserved, args.threads,
args.gpu_graph)

# Plot results
if not args.no_plot:
plot_scatter(distMat,
output,
output + " distances")
plot_database_evaluations(output)

#******************************#
#* *#
Expand Down Expand Up @@ -545,6 +558,7 @@ def main():
args.score_idx,
args.no_local,
args.betweenness_sample,
args.summary_sample,
args.gpu_graph)
model = new_model
elif args.fit_model == "threshold":
Expand Down Expand Up @@ -613,6 +627,7 @@ def main():
model.within_label,
distMat = distMat,
weights_type = weights_type,
sample_size = args.summary_sample,
betweenness_sample = args.betweenness_sample,
use_gpu = args.gpu_graph)
else:
Expand All @@ -628,6 +643,7 @@ def main():
refList,
assignments[rank],
weights = weights,
sample_size = args.summary_sample,
betweenness_sample = args.betweenness_sample,
use_gpu = args.gpu_graph,
summarise = False
Expand Down Expand Up @@ -685,6 +701,7 @@ def main():
queryList,
indivAssignments,
model.within_label,
sample_size = args.summary_sample,
betweenness_sample = args.betweenness_sample,
use_gpu = args.gpu_graph)
isolateClustering[dist_type] = \
Expand Down
2 changes: 1 addition & 1 deletion PopPUNK/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def assign_query_hdf5(dbFuncs,
storePickle(combined_seq, combined_seq, True, None, dists_out)

# Clique pruning
if model.type != 'lineage':
if model.type != 'lineage' and os.path.isfile(ref_file_name):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check this one against #322 – I think I remember changing it there perhaps. Might be best just to remove this change for now

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it changes the behaviour if there is no reference file, as in the current GPS database - I suppose it just depends what we want the behaviour to be in such a situation.

existing_ref_list = []
with open(ref_file_name) as refFile:
for reference in refFile:
Expand Down
29 changes: 27 additions & 2 deletions PopPUNK/bgmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,36 @@ def fit2dMultiGaussian(X, dpgmm_max_K = 2):
return dpgmm


def findBetweenLabel_bgmm(means, assignments, rank = 0):
"""Identify between-strain links

Finds the component with the largest number of points
assigned to it

Args:
means (numpy.array)
K x 2 array of mixture component means
assignments (numpy.array)
Sample cluster assignments
rank (int)
Which label to find, ordered by distance from origin. 0-indexed.
(default = 0)
nickjcroucher marked this conversation as resolved.
Show resolved Hide resolved

Returns:
between_label (int)
The cluster label with the most points assigned to it
"""
most_dists = {}
for mixture_component, distance in enumerate(np.apply_along_axis(np.linalg.norm, 1, means)):
most_dists[mixture_component] = np.count_nonzero(assignments == mixture_component)

sorted_dists = sorted(most_dists.items(), key=operator.itemgetter(1), reverse=True)
return(sorted_dists[rank][0])

def findWithinLabel(means, assignments, rank = 0):
"""Identify within-strain links

Finds the component with mean closest to the origin and also akes sure
Finds the component with mean closest to the origin and also makes sure
some samples are assigned to it (in the case of small weighted
components with a Dirichlet prior some components are unused)

Expand All @@ -59,7 +85,6 @@ def findWithinLabel(means, assignments, rank = 0):
Sample cluster assignments
rank (int)
Which label to find, ordered by distance from origin. 0-indexed.

(default = 0)

Returns:
Expand Down
55 changes: 30 additions & 25 deletions PopPUNK/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,13 @@ def get_options():
# main code
def main():

# Import value
from .__main__ import betweenness_sample_default

# Import functions
from .network import load_network_file
from .network import sparse_mat_to_network
from .network import print_network_summary
from .utils import check_and_set_gpu
from .utils import setGtThreads

Expand Down Expand Up @@ -103,6 +107,32 @@ def main():
use_rc = ref_db['sketches'].attrs['use_rc'] == 1
print("Uses canonical k-mers:\t" + str(use_rc))

# Select network file name
if args.network_file is None:
if use_gpu:
network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.csv.gz')
else:
network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.gt')
else:
network_file = args.network_file

# Open network file
if network_file.endswith('.gt'):
G = load_network_file(network_file, use_gpu = False)
elif network_file.endswith('.csv.gz'):
if use_gpu:
G = load_network_file(network_file, use_gpu = True)
else:
sys.stderr.write('Unable to load necessary GPU libraries\n')
sys.exit(1)
elif network_file.endswith('.npz'):
sparse_mat = sparse.load_npz(network_file)
G = sparse_mat_to_network(sparse_mat, sample_names, use_gpu = use_gpu)
else:
sys.stderr.write('Unrecognised suffix: expected ".gt", ".csv.gz" or ".npz"\n')
sys.exit(1)
print_network_summary(G, betweenness_sample = betweenness_sample_default, use_gpu = args.use_gpu)

# Print sample information
if not args.simple:
sample_names = list(ref_db['sketches'].keys())
Expand All @@ -115,31 +145,6 @@ def main():
sample_sequence_length[sample_name] = ref_db['sketches/' + sample_name].attrs['length']
sample_missing_bases[sample_name] = ref_db['sketches/' + sample_name].attrs['missing_bases']

# Select network file name
if args.network_file is None:
if use_gpu:
network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.csv.gz')
else:
network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.gt')
else:
network_file = args.network_file

# Open network file
if network_file.endswith('.gt'):
G = load_network_file(network_file, use_gpu = False)
elif network_file.endswith('.csv.gz'):
if use_gpu:
G = load_network_file(network_file, use_gpu = True)
else:
sys.stderr.write('Unable to load necessary GPU libraries\n')
sys.exit(1)
elif network_file.endswith('.npz'):
sparse_mat = sparse.load_npz(network_file)
G = sparse_mat_to_network(sparse_mat, sample_names, use_gpu = use_gpu)
else:
sys.stderr.write('Unrecognised suffix: expected ".gt", ".csv.gz" or ".npz"\n')
sys.exit(1)

# Analyse network
if use_gpu:
component_assignments_df = cugraph.components.connectivity.connected_components(G)
Expand Down
Loading
Loading