diff --git a/alphapulldown/create_custom_template_db.py b/alphapulldown/create_custom_template_db.py index 9099d20c..798b6afa 100644 --- a/alphapulldown/create_custom_template_db.py +++ b/alphapulldown/create_custom_template_db.py @@ -15,9 +15,8 @@ from pathlib import Path from absl import logging, flags, app from alphapulldown.remove_clashes_low_plddt import MmcifChainFiltered -from colabfold.batch import validate_and_fix_mmcif, convert_pdb_to_mmcif +from colabfold.batch import validate_and_fix_mmcif from alphafold.common.protein import _from_bio_structure, to_mmcif -from Bio import SeqIO FLAGS = flags.FLAGS @@ -92,40 +91,6 @@ def create_tree(pdb_mmcif_dir, mmcif_dir, seqres_dir, templates_dir): create_dir_and_remove_files(seqres_dir, ['pdb_seqres.txt']) -def extract_seqs(template, chain_id): - """ - Extract sequences from PDB/CIF file using Bio.SeqIO. - o input_file_path - path to the input file - o chain_id - chain ID - Returns: - o sequence_atom - sequence from ATOM records - o sequence_seqres - sequence from SEQRES records - """ - file_type = template.suffix.lower() - - if template.suffix.lower() != '.pdb' and template.suffix.lower() != '.cif': - raise ValueError(f"Unknown file type for {template}!") - - format_types = [f"{file_type[1:]}-atom", f"{file_type[1:]}-seqres"] - # initialize the sequences - sequence_atom = None - sequence_seqres = None - # parse - for format_type in format_types: - for record in SeqIO.parse(template, format_type): - chain = record.annotations['chain'] - if chain == chain_id: - if format_type.endswith('atom'): - sequence_atom = str(record.seq) - elif format_type.endswith('seqres'): - sequence_seqres = str(record.seq) - if sequence_atom is None: - logging.error(f"No atom sequence found for chain {chain_id}") - if sequence_seqres is None: - logging.warning(f"No SEQRES sequence found for chain {chain_id}") - return sequence_atom, sequence_seqres - - def create_db(out_path, templates, chains, threshold_clashes, hb_allowance, plddt_threshold): """ Main function that creates a custom template database for AlphaFold2 @@ -155,26 +120,13 @@ def create_db(out_path, templates, chains, threshold_clashes, hb_allowance, pldd shutil.copyfile(template, new_template) template = new_template logging.info(f"Processing template: {template} Chain {chain_id}") - logging.info("Parsing SEQRES...") - atom_seq, seqres_seq = None, None - if template.suffix == '.pdb': - atom_seq, seqres_seq = extract_seqs(template, chain_id) - logging.info(f"Converting to mmCIF: {template}") - template = Path(template) - convert_pdb_to_mmcif(template) - template = template.parent.joinpath(f"{template.stem}.cif") # Convert to (our) mmcif object mmcif_obj = MmcifChainFiltered(template, code, chain_id) - # Parse SEQRES + # full sequence is either SEQRES or parsed from (original) ATOMs if mmcif_obj.sequence_seqres: seqres = mmcif_obj.sequence_seqres else: seqres = mmcif_obj.sequence_atom - # if we converted from pdb, seqres is parsed from Bio.SeqIO - if seqres_seq or atom_seq: - seqres = seqres_seq - if seqres is None: - seqres = atom_seq sqrres_path = save_seqres(code, chain_id, seqres, seqres_dir) logging.info(f"SEQRES saved to {sqrres_path}!") # Remove clashes and low pLDDT regions for each template diff --git a/alphapulldown/remove_clashes_low_plddt.py b/alphapulldown/remove_clashes_low_plddt.py index f84d2d18..cfa6f40b 100755 --- a/alphapulldown/remove_clashes_low_plddt.py +++ b/alphapulldown/remove_clashes_low_plddt.py @@ -1,4 +1,6 @@ from collections import defaultdict +from pathlib import Path + from absl import app, flags import logging import copy @@ -6,7 +8,41 @@ from alphafold.common.residue_constants import residue_atoms from Bio.PDB import Structure, NeighborSearch, PDBIO, MMCIFIO from Bio.PDB.Polypeptide import protein_letters_3to1 +from Bio import SeqIO +from colabfold.batch import convert_pdb_to_mmcif + +def extract_seqs(template, chain_id=None): + """ + Extract sequences from PDB/CIF file using Bio.SeqIO. + o input_file_path - path to the input file + o chain_id - chain ID + Returns: + o sequence_atom - sequence from ATOM records + o sequence_seqres - sequence from SEQRES records + """ + file_type = template.suffix.lower() + if template.suffix.lower() != '.pdb' and template.suffix.lower() != '.cif': + raise ValueError(f"Unknown file type for {template}!") + + format_types = [f"{file_type[1:]}-atom", f"{file_type[1:]}-seqres"] + # initialize the sequences + sequence_atom = None + sequence_seqres = None + # parse + for format_type in format_types: + for record in SeqIO.parse(template, format_type): + chain = record.annotations['chain'] + if chain == chain_id: + if format_type.endswith('atom'): + sequence_atom = str(record.seq) + elif format_type.endswith('seqres'): + sequence_seqres = str(record.seq) + if sequence_atom is None: + logging.error(f"No atom sequence found for chain {chain_id}") + if sequence_seqres is None: + logging.warning(f"No SEQRES sequence found for chain {chain_id}") + return sequence_atom, sequence_seqres class MmcifChainFiltered: """ @@ -20,21 +56,34 @@ class MmcifChainFiltered: def __init__(self, input_file_path, code, chain_id=None): self.input_file_path = input_file_path self.chain_id = chain_id + logging.info("Parsing SEQRES...") + self.sequence_atom, self.sequence_seqres = extract_seqs(input_file_path, chain_id) + if input_file_path.suffix == '.pdb': + logging.info(f"Converting to mmCIF: {input_file_path}") + input_file_path = Path(input_file_path) + convert_pdb_to_mmcif(input_file_path) + input_file_path = input_file_path.parent.joinpath(f"{input_file_path.stem}.cif") with open(input_file_path) as f: mmcif = f.read() parsing_result = parse(file_id=code, mmcif_string=mmcif) if parsing_result.errors: raise Exception(f"Can't parse mmcif file {input_file_path}: {parsing_result.errors}") mmcif_object = parse(file_id=code, mmcif_string=mmcif).mmcif_object - self.sequence_seqres = mmcif_object.chain_to_seqres[chain_id] self.seqres_to_structure = mmcif_object.seqres_to_structure[chain_id] - self.structure, self.sequence_atom = self.extract_chain(mmcif_object.structure, chain_id) + #self.sequence_seqres = mmcif_object.chain_to_seqres[chain_id] + self.structure, sequence_atom = self.extract_chain(mmcif_object.structure, chain_id) + if str(self.sequence_atom) != str(sequence_atom): + logging.info("Template structure was modified!") + logging.info(f"original ATOM sequence: {self.sequence_atom}") + logging.info(f"modified ATOM sequence: {sequence_atom}") + self.sequence_atom = sequence_atom self.structure_modified = False def __eq__(self, other): return self.structure == other.structure + def extract_atom_site_label_seq_id(self): """ Extracts residue index for atoms. @@ -48,6 +97,7 @@ def extract_atom_site_label_seq_id(self): atoms_label_seq_id += [str(label_id + 1)] * number_of_atoms_in_residue return atoms_label_seq_id + def extract_chain(self, model, chain_id): """ Extracts a chain and parses sequence from atoms. diff --git a/test/test_features_with_templates.py b/test/test_features_with_templates.py index 06e59f29..769ebf1c 100644 --- a/test/test_features_with_templates.py +++ b/test/test_features_with_templates.py @@ -4,7 +4,7 @@ import alphapulldown.create_individual_features_with_templates as run_features_generation import pickle import numpy as np -from alphapulldown.create_custom_template_db import extract_seqs +from alphapulldown.remove_clashes_low_plddt import extract_seqs class TestCreateIndividualFeaturesWithTemplates(absltest.TestCase): @@ -71,6 +71,7 @@ def run_features_generation(self, file_name, chain_id, file_extension): temp_sequence = feats['template_sequence'][0].decode('utf-8') target_sequence = feats['sequence'][0].decode('utf-8') atom_coords = feats['template_all_atom_positions'][0] + # Check that template sequence is not empty assert len(temp_sequence) > 0 # Check that the atom coordinates are not all 0 assert (atom_coords.any()) > 0 @@ -83,33 +84,30 @@ def run_features_generation(self, file_name, chain_id, file_extension): print(f"seq-seqres: {seqres_seq}") if seqres_seq: print(len(seqres_seq)) + # SeqIO adds X for missing residues for atom-seq print(f"seq-atom: {atom_seq}") print(len(atom_seq)) - # Check that atoms with non-zero coordinates are identical in seq-seqres and seq-atom + # Check that atoms for not missing residues are not all 0 residue_has_nonzero_coords = [] - atom_id = -1 - for number, (s, a) in enumerate(zip(temp_sequence, atom_coords)): - # if mismatch between target and seqres - if s == '-': + for number, (s, a) in enumerate(zip(atom_seq, atom_coords)): + # no coordinates for missing residues + if s == 'X': assert np.all(a == 0) residue_has_nonzero_coords.append(False) else: non_zero = np.any(a != 0) residue_has_nonzero_coords.append(non_zero) if non_zero: - atom_id += 1 if seqres_seq: seqres = seqres_seq[number] else: seqres = None - print(f"template-seq: {s} atom-seq: {atom_seq[atom_id]} seqres-seq: {seqres} id: {atom_id}") if seqres: - assert (s == seqres_seq[number] or s == atom_seq[atom_id]) #seqres can be different from atomseq - else: - assert (s == atom_seq[atom_id]) + assert (s in seqres_seq) + # first 4 coordinates are non zero assert np.any(a[:4] != 0) - print(residue_has_nonzero_coords) - print(len(residue_has_nonzero_coords)) + #print(residue_has_nonzero_coords) + #print(len(residue_has_nonzero_coords)) def test_1a_run_features_generation(self): self.run_features_generation('3L4Q', 'A', 'cif') @@ -123,7 +121,7 @@ def test_3b_bizarre_filename(self): def test_4c_bizarre_filename(self): self.run_features_generation('RANdom_name1_.7-1_0', 'C', 'pdb') - def test_4c_gappy_pdb(self): + def test_5b_gappy_pdb(self): self.run_features_generation('GAPPY_PDB', 'B', 'pdb') if __name__ == '__main__':