Skip to content

Commit

Permalink
Fix for gappy test pdb
Browse files Browse the repository at this point in the history
  • Loading branch information
DimaMolod committed Nov 13, 2023
1 parent c6c9bef commit 404fd65
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 66 deletions.
52 changes: 2 additions & 50 deletions alphapulldown/create_custom_template_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
from pathlib import Path
from absl import logging, flags, app
from alphapulldown.remove_clashes_low_plddt import MmcifChainFiltered
from colabfold.batch import validate_and_fix_mmcif, convert_pdb_to_mmcif
from colabfold.batch import validate_and_fix_mmcif
from alphafold.common.protein import _from_bio_structure, to_mmcif
from Bio import SeqIO

FLAGS = flags.FLAGS

Expand Down Expand Up @@ -92,40 +91,6 @@ def create_tree(pdb_mmcif_dir, mmcif_dir, seqres_dir, templates_dir):
create_dir_and_remove_files(seqres_dir, ['pdb_seqres.txt'])


def extract_seqs(template, chain_id):
"""
Extract sequences from PDB/CIF file using Bio.SeqIO.
o input_file_path - path to the input file
o chain_id - chain ID
Returns:
o sequence_atom - sequence from ATOM records
o sequence_seqres - sequence from SEQRES records
"""
file_type = template.suffix.lower()

if template.suffix.lower() != '.pdb' and template.suffix.lower() != '.cif':
raise ValueError(f"Unknown file type for {template}!")

format_types = [f"{file_type[1:]}-atom", f"{file_type[1:]}-seqres"]
# initialize the sequences
sequence_atom = None
sequence_seqres = None
# parse
for format_type in format_types:
for record in SeqIO.parse(template, format_type):
chain = record.annotations['chain']
if chain == chain_id:
if format_type.endswith('atom'):
sequence_atom = str(record.seq)
elif format_type.endswith('seqres'):
sequence_seqres = str(record.seq)
if sequence_atom is None:
logging.error(f"No atom sequence found for chain {chain_id}")
if sequence_seqres is None:
logging.warning(f"No SEQRES sequence found for chain {chain_id}")
return sequence_atom, sequence_seqres


def create_db(out_path, templates, chains, threshold_clashes, hb_allowance, plddt_threshold):
"""
Main function that creates a custom template database for AlphaFold2
Expand Down Expand Up @@ -155,26 +120,13 @@ def create_db(out_path, templates, chains, threshold_clashes, hb_allowance, pldd
shutil.copyfile(template, new_template)
template = new_template
logging.info(f"Processing template: {template} Chain {chain_id}")
logging.info("Parsing SEQRES...")
atom_seq, seqres_seq = None, None
if template.suffix == '.pdb':
atom_seq, seqres_seq = extract_seqs(template, chain_id)
logging.info(f"Converting to mmCIF: {template}")
template = Path(template)
convert_pdb_to_mmcif(template)
template = template.parent.joinpath(f"{template.stem}.cif")
# Convert to (our) mmcif object
mmcif_obj = MmcifChainFiltered(template, code, chain_id)
# Parse SEQRES
# full sequence is either SEQRES or parsed from (original) ATOMs
if mmcif_obj.sequence_seqres:
seqres = mmcif_obj.sequence_seqres
else:
seqres = mmcif_obj.sequence_atom
# if we converted from pdb, seqres is parsed from Bio.SeqIO
if seqres_seq or atom_seq:
seqres = seqres_seq
if seqres is None:
seqres = atom_seq
sqrres_path = save_seqres(code, chain_id, seqres, seqres_dir)
logging.info(f"SEQRES saved to {sqrres_path}!")
# Remove clashes and low pLDDT regions for each template
Expand Down
54 changes: 52 additions & 2 deletions alphapulldown/remove_clashes_low_plddt.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,48 @@
from collections import defaultdict
from pathlib import Path

from absl import app, flags
import logging
import copy
from alphafold.data.mmcif_parsing import parse
from alphafold.common.residue_constants import residue_atoms
from Bio.PDB import Structure, NeighborSearch, PDBIO, MMCIFIO
from Bio.PDB.Polypeptide import protein_letters_3to1
from Bio import SeqIO
from colabfold.batch import convert_pdb_to_mmcif

def extract_seqs(template, chain_id=None):
"""
Extract sequences from PDB/CIF file using Bio.SeqIO.
o input_file_path - path to the input file
o chain_id - chain ID
Returns:
o sequence_atom - sequence from ATOM records
o sequence_seqres - sequence from SEQRES records
"""

file_type = template.suffix.lower()
if template.suffix.lower() != '.pdb' and template.suffix.lower() != '.cif':
raise ValueError(f"Unknown file type for {template}!")

format_types = [f"{file_type[1:]}-atom", f"{file_type[1:]}-seqres"]
# initialize the sequences
sequence_atom = None
sequence_seqres = None
# parse
for format_type in format_types:
for record in SeqIO.parse(template, format_type):
chain = record.annotations['chain']
if chain == chain_id:
if format_type.endswith('atom'):
sequence_atom = str(record.seq)
elif format_type.endswith('seqres'):
sequence_seqres = str(record.seq)
if sequence_atom is None:
logging.error(f"No atom sequence found for chain {chain_id}")
if sequence_seqres is None:
logging.warning(f"No SEQRES sequence found for chain {chain_id}")
return sequence_atom, sequence_seqres

class MmcifChainFiltered:
"""
Expand All @@ -20,21 +56,34 @@ class MmcifChainFiltered:
def __init__(self, input_file_path, code, chain_id=None):
self.input_file_path = input_file_path
self.chain_id = chain_id
logging.info("Parsing SEQRES...")
self.sequence_atom, self.sequence_seqres = extract_seqs(input_file_path, chain_id)
if input_file_path.suffix == '.pdb':
logging.info(f"Converting to mmCIF: {input_file_path}")
input_file_path = Path(input_file_path)
convert_pdb_to_mmcif(input_file_path)
input_file_path = input_file_path.parent.joinpath(f"{input_file_path.stem}.cif")
with open(input_file_path) as f:
mmcif = f.read()
parsing_result = parse(file_id=code, mmcif_string=mmcif)
if parsing_result.errors:
raise Exception(f"Can't parse mmcif file {input_file_path}: {parsing_result.errors}")
mmcif_object = parse(file_id=code, mmcif_string=mmcif).mmcif_object
self.sequence_seqres = mmcif_object.chain_to_seqres[chain_id]
self.seqres_to_structure = mmcif_object.seqres_to_structure[chain_id]
self.structure, self.sequence_atom = self.extract_chain(mmcif_object.structure, chain_id)
#self.sequence_seqres = mmcif_object.chain_to_seqres[chain_id]
self.structure, sequence_atom = self.extract_chain(mmcif_object.structure, chain_id)
if str(self.sequence_atom) != str(sequence_atom):
logging.info("Template structure was modified!")
logging.info(f"original ATOM sequence: {self.sequence_atom}")
logging.info(f"modified ATOM sequence: {sequence_atom}")
self.sequence_atom = sequence_atom
self.structure_modified = False


def __eq__(self, other):
return self.structure == other.structure


def extract_atom_site_label_seq_id(self):
"""
Extracts residue index for atoms.
Expand All @@ -48,6 +97,7 @@ def extract_atom_site_label_seq_id(self):
atoms_label_seq_id += [str(label_id + 1)] * number_of_atoms_in_residue
return atoms_label_seq_id


def extract_chain(self, model, chain_id):
"""
Extracts a chain and parses sequence from atoms.
Expand Down
26 changes: 12 additions & 14 deletions test/test_features_with_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import alphapulldown.create_individual_features_with_templates as run_features_generation
import pickle
import numpy as np
from alphapulldown.create_custom_template_db import extract_seqs
from alphapulldown.remove_clashes_low_plddt import extract_seqs


class TestCreateIndividualFeaturesWithTemplates(absltest.TestCase):
Expand Down Expand Up @@ -71,6 +71,7 @@ def run_features_generation(self, file_name, chain_id, file_extension):
temp_sequence = feats['template_sequence'][0].decode('utf-8')
target_sequence = feats['sequence'][0].decode('utf-8')
atom_coords = feats['template_all_atom_positions'][0]
# Check that template sequence is not empty
assert len(temp_sequence) > 0
# Check that the atom coordinates are not all 0
assert (atom_coords.any()) > 0
Expand All @@ -83,33 +84,30 @@ def run_features_generation(self, file_name, chain_id, file_extension):
print(f"seq-seqres: {seqres_seq}")
if seqres_seq:
print(len(seqres_seq))
# SeqIO adds X for missing residues for atom-seq
print(f"seq-atom: {atom_seq}")
print(len(atom_seq))
# Check that atoms with non-zero coordinates are identical in seq-seqres and seq-atom
# Check that atoms for not missing residues are not all 0
residue_has_nonzero_coords = []
atom_id = -1
for number, (s, a) in enumerate(zip(temp_sequence, atom_coords)):
# if mismatch between target and seqres
if s == '-':
for number, (s, a) in enumerate(zip(atom_seq, atom_coords)):
# no coordinates for missing residues
if s == 'X':
assert np.all(a == 0)
residue_has_nonzero_coords.append(False)
else:
non_zero = np.any(a != 0)
residue_has_nonzero_coords.append(non_zero)
if non_zero:
atom_id += 1
if seqres_seq:
seqres = seqres_seq[number]
else:
seqres = None
print(f"template-seq: {s} atom-seq: {atom_seq[atom_id]} seqres-seq: {seqres} id: {atom_id}")
if seqres:
assert (s == seqres_seq[number] or s == atom_seq[atom_id]) #seqres can be different from atomseq
else:
assert (s == atom_seq[atom_id])
assert (s in seqres_seq)
# first 4 coordinates are non zero
assert np.any(a[:4] != 0)
print(residue_has_nonzero_coords)
print(len(residue_has_nonzero_coords))
#print(residue_has_nonzero_coords)
#print(len(residue_has_nonzero_coords))

def test_1a_run_features_generation(self):
self.run_features_generation('3L4Q', 'A', 'cif')
Expand All @@ -123,7 +121,7 @@ def test_3b_bizarre_filename(self):
def test_4c_bizarre_filename(self):
self.run_features_generation('RANdom_name1_.7-1_0', 'C', 'pdb')

def test_4c_gappy_pdb(self):
def test_5b_gappy_pdb(self):
self.run_features_generation('GAPPY_PDB', 'B', 'pdb')

if __name__ == '__main__':
Expand Down

0 comments on commit 404fd65

Please sign in to comment.