Skip to content

Commit

Permalink
Merge pull request #200 from KosinskiLab/add-unifold
Browse files Browse the repository at this point in the history
Integrated AlphaLink2 inference with AlphaPulldown; Added PAE plotting to AlphaLink2 inference
  • Loading branch information
dingquanyu authored Nov 9, 2023
2 parents 00f2a5b + dc42680 commit ed7218b
Show file tree
Hide file tree
Showing 10 changed files with 247 additions and 29 deletions.
46 changes: 32 additions & 14 deletions .github/workflows/github_actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,21 @@ jobs:
conda develop alphapulldown
conda develop alphapulldown/ColabFold
conda develop alphafold
#conda info
#conda list
#python -c "import sys; print(sys.path)"
#echo "Python path: $PYTHONPATH"
python -c "import alphafold; import os; print('Alphafold module is located at:', alphafold.__file__); alphafold_dir = os.path.dirname(alphafold.__file__); print('Contents of the Alphafold directory:', os.listdir(alphafold_dir))"
- name: Install dependencies in unifold setup.py
run: |
eval "$(conda shell.bash hook)"
export WORKDIRPATH=$PWD
conda activate AlphaPulldown
cd $WORKDIRPATH/unifold && python3 setup.py install
- name: Install dependencies in AlphaFold setup.py
run: |
eval "$(conda shell.bash hook)"
export WORKDIRPATH=$PWD
conda activate AlphaPulldown
pip install protobuf==4.23.0
cd $WORKDIRPATH/alphafold && python3 setup.py install
python -c "import tree"
- name: Install Dependencies
run: |
eval "$(conda shell.bash hook)"
Expand All @@ -51,17 +61,25 @@ jobs:
pip install torch --pre -f https://download.pytorch.org/whl/nightly/cu121/torch_nightly.html
pip install jax==0.3.25 jaxlib==0.3.25+cuda11.cudnn805 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install "absl-py>=0.13.0" dm-haiku "dm-tree>=0.1.6" "h5py>=3.1.0" "ml-collections>=0.1.0" "pandas>=1.3.4" tensorflow "importlib-resources==5.8.0" "nbformat==5.4.0" "py3Dmol==2.0.1" ipython appdirs jupyterlab ipywidgets pytest
# - name: Check imports of submodules
# run : |
# eval "$(conda shell.bash hook)"
# conda activate AlphaPulldown
# # install PyTorch
# pip install torch --pre -f https://download.pytorch.org/whl/nightly/cu121/torch_nightly.html
# # setup unicore
# git clone https://github.com/dptech-corp/Uni-Core.git
# cd Uni-Core
# python3 setup.py install --disable-cuda-ext
# python -c "from unifold.alphalink_inference import alphalink_prediction"
# python -c "from unifold.dataset import process_ap"
# python -c "from unifold.config import model_config"
# python -c "from colabfold.batch import get_queries, unserialize_msa, get_msa_and_templates, msa_to_str, build_monomer_feature, parse_fasta"
- name: Run Tests
run: |
eval "$(conda shell.bash hook)"
conda activate AlphaPulldown
python3 -m unittest discover -s test -t test
- name: Build and Push to Docker Hub
uses: mr-smithers-excellent/docker-build-push@v5
with:
image: dmolodenskiy/AlphaPulldown
registry: docker.io
dockerfile: docker/Dockerfile
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
tags: latest
pytest -s test/test_custom_db.py
pytest -s test/test_remove_clashes_low_plddt.py
# python3 -m unittest test/test_crosslink_input.py
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,7 @@
path = alphapulldown/analysis_pipeline/af2plots
url = https://gitlab.com/gchojnowski/af2plots.git
branch = main
[submodule "unifold"]
path = unifold
url = https://github.com/dingquanyu/Uni-Fold.git
branch = main
2 changes: 1 addition & 1 deletion alphafold
31 changes: 31 additions & 0 deletions alphapulldown/plot_pae.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,34 @@ def plot_pae(seqs: list, order, feature_dir, job_name):
ax1.axvline(t, color="black", linewidth=3.5)
plt.title("ranked_{}".format(i))
plt.savefig(f"{feature_dir}/{job_name}_PAE_plot_ranked_{i}.png")

def plot_pae_from_matrix(seqs,pae_matrix,figure_name=''):
xticks = []
initial_tick = 0
for s in seqs:
initial_tick = initial_tick + len(s)
xticks.append(initial_tick)

xticks_labels = []
for i, t in enumerate(xticks):
xticks_labels.append(str(i + 1))

yticks_labels = []
for s in seqs:
yticks_labels.append(str(len(s)))
fig, ax1 = plt.subplots(1, 1)
# plt.figure(figsize=(3,18))
check = pae_matrix
fig, ax1 = plt.subplots(1, 1)
pos = ax1.imshow(check, cmap="bwr", vmin=0, vmax=30)
ax1.set_xticks(xticks)
ax1.set_yticks(xticks)

ax1.set_xticklabels(xticks_labels, size="large")
ax1.set_yticklabels(yticks_labels,size="large")
fig.colorbar(pos).ax.set_title("unit: Angstrom")
for t in xticks:
ax1.axhline(t, color="black", linewidth=3.5)
ax1.axvline(t, color="black", linewidth=3.5)
plt.title("ranked_{}".format(i))
plt.savefig(figure_name)
78 changes: 64 additions & 14 deletions alphapulldown/run_multimer_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,24 @@
flags.DEFINE_integer(
"msa_depth", None, "Number of sequences to use from the MSA (by default is taken from AF model config)"
)
flags.DEFINE_boolean(
"use_unifold",False,"Whether unifold models are going to be used. Default it False"
)

flags.DEFINE_boolean(
"use_alphalink",False,"Whether alphalink models are going to be used. Default it False"
)
flags.DEFINE_string(
"crosslinks",None,"Path to crosslink information pickle"
)
flags.DEFINE_string(
"alphalink_weight",None,'Path to AlphaLink neural network weights'
)
flags.DEFINE_string(
"unifold_param",None,'Path to UniFold neural network weights'
)
flags.DEFINE_enum("unifold_model_name","multimer_af2",
["multimer_af2","multimer_ft","multimer","multimer_af2_v3","multimer_af2_model45_v3"],"choose unifold model structure")
flags.mark_flag_as_required("output_path")

delattr(flags.FLAGS, "models_to_relax")
Expand Down Expand Up @@ -305,20 +322,53 @@ def predict_individual_jobs(multimer_object, output_path, model_runners, random_
if not isinstance(multimer_object, MultimericObject):
multimer_object.input_seqs = [multimer_object.sequence]


predict(
model_runners,
output_path,
multimer_object.feature_dict,
random_seed,
FLAGS.benchmark,
fasta_name=multimer_object.description,
models_to_relax=FLAGS.models_to_relax,
seqs=multimer_object.input_seqs,
)
create_and_save_pae_plots(multimer_object, output_path)


if FLAGS.use_unifold:
from unifold.inference import config_args,unifold_config_model,unifold_predict
from unifold.dataset import process_ap
from unifold.config import model_config
configs = model_config(FLAGS.unifold_model_name)
general_args = config_args(FLAGS.unifold_param,
target_name=multimer_object.description,
output_dir=output_path)
model_runner = unifold_config_model(general_args)
# First need to add num_recycling_iters to the feature dictionary
# multimer_object.feature_dict.update({"num_recycling_iters":general_args.max_recycling_iters})
processed_features,_ = process_ap(config=configs.data,
features=multimer_object.feature_dict,
mode="predict",labels=None,
seed=42,batch_idx=None,
data_idx=None,is_distillation=False
)
logging.info(f"finished configuring the Unifold AlphlaFold model and process numpy features")
unifold_predict(model_runner,general_args,processed_features)

elif FLAGS.use_alphalink:
assert FLAGS.crosslinks is not None
assert FLAGS.alphalink_weight is not None
from unifold.alphalink_inference import alphalink_prediction

from unifold.config import model_config
logging.info(f"Start using AlphaLink weights and cross-link information")
MODEL_NAME = 'model_5_ptm_af2'
configs = model_config(MODEL_NAME)
alphalink_prediction(multimer_object.feature_dict,
os.path.join(FLAGS.output_path,multimer_object.description),
input_seqs = multimer_object.input_seqs,
param_path = FLAGS.alphalink_weight,
configs = configs,crosslinks=FLAGS.crosslinks,
chain_id_map=multimer_object.chain_id_map)
else:
predict(
model_runners,
output_path,
multimer_object.feature_dict,
random_seed,
FLAGS.benchmark,
fasta_name=multimer_object.description,
models_to_relax=FLAGS.models_to_relax,
seqs=multimer_object.input_seqs,
)
create_and_save_pae_plots(multimer_object, output_path)


def predict_multimers(multimers):
Expand Down
65 changes: 65 additions & 0 deletions test/test_crosslink_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import shutil
import tempfile
import unittest
import sys
import os
import torch
from unifold.modules.alphafold import AlphaFold
from unifold.alphalink_inference import prepare_model_runner
from unifold.alphalink_inference import alphalink_prediction
from unifold.dataset import process_ap
from unifold.config import model_config
from alphapulldown.utils import create
from alphapulldown.run_multimer_jobs import predict_individual_jobs,create_custom_jobs

class _TestBase(unittest.TestCase):
def setUp(self) -> None:
self.crosslink_file_path = os.path.join(os.path.dirname(__file__),"test_data/example_crosslink.pkl.gz")
self.config_data_model_name = 'model_5_ptm_af2'
self.config_alphafold_model_name = 'multimer_af2_crop'

class TestCrosslinkInference(_TestBase):
def setUp(self) -> None:
super().setUp()
self.output_dir = tempfile.mkdtemp()
self.monomer_object_path = os.path.join(os.path.dirname(__file__),"test_data/")
self.protein_list = os.path.join(os.path.dirname(__file__),"test_data/example_crosslinked_pair.txt")
self.alphalink_weight = '/g/alphafold/alphalink_weights/AlphaLink-Multimer_SDA_v3.pt'
self.multimerobject = create_custom_jobs(self.protein_list,self.monomer_object_path,job_index=1,pair_msa=True)[0]

def test1_process_features(self):
"""Test whether the PyTorch model of AlphaLink can be initiated successfully"""
configs = model_config(self.config_data_model_name)
processed_features,_ = process_ap(config=configs.data,
features=self.multimerobject.feature_dict,
mode="predict",labels=None,
seed=42,batch_idx=None,
data_idx=None,is_distillation=False,
chain_id_map = self.multimerobject.chain_id_map,
crosslinks = self.crosslink_file_path
)

def test2_load_AlphaLink_weights(self):
"""This is testing weither loading the PyTorch checkpoint is sucessfull"""
if torch.cuda.is_available():
model_device = 'cuda:0'
else:
model_device = 'cpu'

config = model_config(self.config_alphafold_model_name)
model = AlphaFold(config)
state_dict = torch.load(self.alphalink_weight)["ema"]["params"]
state_dict = {".".join(k.split(".")[1:]): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.to(model_device)

def test3_test_inference(self):
if torch.cuda.is_available():
model_device = 'cuda:0'
else:
model_device = 'cpu'
model = prepare_model_runner(self.alphalink_weight,model_device=model_device)


if __name__ == '__main__':
unittest.main()
49 changes: 49 additions & 0 deletions test/test_crosslink_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import unittest
from unifold.dataset import calculate_offsets,create_xl_features,bin_xl
from alphafold.data.pipeline_multimer import _FastaChain
import numpy as np
import gzip,pickle
import torch
class TestCreateObjects(unittest.TestCase):
def setUp(self) -> None:
self.crosslink_info ="./test/test_data/test_xl_input.pkl.gz"
self.asym_id = [1]*10 + [2]*25 + [3]*40
self.chain_id_map = {
"A":_FastaChain(sequence='',description='chain1'),
"B":_FastaChain(sequence='',description='chain2'),
"C":_FastaChain(sequence='',description='chain3')
}
self.bins = torch.arange(0,1.05,0.05)
return super().setUp()

def test1_calculate_offsets(self):
offsets = calculate_offsets(self.asym_id)
offsets = offsets.tolist()
expected_offsets = [0,10,35,75]
self.assertEqual(offsets,expected_offsets)

def test2_create_xl_inputs(self):
offsets = calculate_offsets(self.asym_id)
xl_pickle = pickle.load(gzip.open(self.crosslink_info,'rb'))
xl = create_xl_features(xl_pickle,offsets,chain_id_map = self.chain_id_map)
expected_xl = torch.tensor([[10,35,0.01],
[3,27,0.01],
[5,56,0.01],
[20,65,0.01]])
self.assertTrue(torch.equal(xl,expected_xl))

def test3_bin_xl(self):
offsets = calculate_offsets(self.asym_id)
xl_pickle = pickle.load(gzip.open(self.crosslink_info,'rb'))
xl = create_xl_features(xl_pickle,offsets,chain_id_map = self.chain_id_map)
num_res = len(self.asym_id)
xl = bin_xl(xl,num_res)
expected_xl = np.zeros((num_res,num_res,1))
expected_xl[3,27,0] = expected_xl[27,3,0] = torch.bucketize(0.99,self.bins)
expected_xl[10,35,0] = expected_xl[35,10,0] = torch.bucketize(0.99,self.bins)
expected_xl[5,56,0] = expected_xl[56,5,0] = torch.bucketize(0.99,self.bins)
expected_xl[20,65,0] = expected_xl[65,20,0] = torch.bucketize(0.99,self.bins)
self.assertTrue(np.array_equal(xl,expected_xl))

if __name__ == "__main__":
unittest.main()
Binary file added test/test_data/example_crosslink.pkl.gz
Binary file not shown.
Binary file added test/test_data/test_xl_input.pkl.gz
Binary file not shown.
1 change: 1 addition & 0 deletions unifold
Submodule unifold added at 726788

0 comments on commit ed7218b

Please sign in to comment.