Integrated AlphaLink2 with AlphaPulldown #200

Merged 42 commits on Nov 9, 2023

Commits
48efe26
update codes
May 20, 2023
2a26316
now make the model structure of unifold flexible with a flag
dingquanyu May 24, 2023
a8eec28
fixed the flags error
dingquanyu May 25, 2023
53b5a76
update gitmodules with unifold submodule
dingquanyu Oct 13, 2023
38e5f1c
start working on adding crosslink input
dingquanyu Oct 13, 2023
39a60a3
added unifold submodule codes
dingquanyu Oct 13, 2023
5e98ed8
added unittest for preparing crosslink infor
dingquanyu Oct 13, 2023
a69c775
update test of process_xl_input functions
dingquanyu Oct 15, 2023
eebac83
finished testing xl input preparation
dingquanyu Oct 15, 2023
4a33322
start working on alphalink inference codes
dingquanyu Oct 15, 2023
2811bf4
finished adding alphalink inference codes
dingquanyu Oct 15, 2023
8406642
Merge pull request #193 from KosinskiLab/main
dingquanyu Oct 26, 2023
855619c
update run_multimer_jobs
dingquanyu Oct 26, 2023
572f9cc
update unifold submodule
dingquanyu Oct 26, 2023
f0dc199
unittest for cross link input and inference with crosslink data
dingquanyu Oct 26, 2023
55608b1
update alphafold, unifold submodules
dingquanyu Nov 1, 2023
fbafbd6
remove unnecessary debug statements
dingquanyu Nov 1, 2023
78c012d
update unifold submodule
dingquanyu Nov 1, 2023
15572fc
update github actions
dingquanyu Nov 1, 2023
479b7d1
update github actions
dingquanyu Nov 1, 2023
636157f
update github actions
dingquanyu Nov 1, 2023
bda5aff
added conda shell.bash hook before testing imports
dingquanyu Nov 1, 2023
beafd98
update github actions
dingquanyu Nov 1, 2023
8bd7f5a
update github actions
dingquanyu Nov 1, 2023
7d522a7
update github actions
dingquanyu Nov 1, 2023
ad75e32
update github actions
dingquanyu Nov 1, 2023
b569722
update github actions
dingquanyu Nov 1, 2023
b14baea
add pip installations
dingquanyu Nov 2, 2023
59caffd
update github actions
dingquanyu Nov 2, 2023
ff4fbc7
update github actions
dingquanyu Nov 2, 2023
299c94f
update github actions
dingquanyu Nov 2, 2023
c8e68e3
update github actions
dingquanyu Nov 2, 2023
1024b97
update github actions
dingquanyu Nov 2, 2023
d4c6a3e
update github actions
dingquanyu Nov 2, 2023
4ab1699
added test data for cross-link input information
dingquanyu Nov 2, 2023
e065104
update alphalink prediction output path
dingquanyu Nov 2, 2023
11d960d
update unifold module
dingquanyu Nov 2, 2023
994d8da
update run_multimer_jobs with the latest updates on alphalink inference
dingquanyu Nov 9, 2023
9d5a66a
update unifold
dingquanyu Nov 9, 2023
728b941
update pae ploting codes for alphalink inference
dingquanyu Nov 9, 2023
8434a59
update unifold
dingquanyu Nov 9, 2023
dc42680
Merge branch 'main' into add-unifold
dingquanyu Nov 9, 2023
46 changes: 32 additions & 14 deletions .github/workflows/github_actions.yml
@@ -35,11 +35,21 @@ jobs:
          conda develop alphapulldown
          conda develop alphapulldown/ColabFold
          conda develop alphafold
          #conda info
          #conda list
          #python -c "import sys; print(sys.path)"
          #echo "Python path: $PYTHONPATH"
          python -c "import alphafold; import os; print('Alphafold module is located at:', alphafold.__file__); alphafold_dir = os.path.dirname(alphafold.__file__); print('Contents of the Alphafold directory:', os.listdir(alphafold_dir))"
      - name: Install dependencies in unifold setup.py
        run: |
          eval "$(conda shell.bash hook)"
          export WORKDIRPATH=$PWD
          conda activate AlphaPulldown
          cd $WORKDIRPATH/unifold && python3 setup.py install
      - name: Install dependencies in AlphaFold setup.py
        run: |
          eval "$(conda shell.bash hook)"
          export WORKDIRPATH=$PWD
          conda activate AlphaPulldown
          pip install protobuf==4.23.0
          cd $WORKDIRPATH/alphafold && python3 setup.py install
          python -c "import tree"
      - name: Install Dependencies
        run: |
          eval "$(conda shell.bash hook)"
@@ -51,17 +61,25 @@ jobs:
          pip install torch --pre -f https://download.pytorch.org/whl/nightly/cu121/torch_nightly.html
          pip install jax==0.3.25 jaxlib==0.3.25+cuda11.cudnn805 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
          pip install "absl-py>=0.13.0" dm-haiku "dm-tree>=0.1.6" "h5py>=3.1.0" "ml-collections>=0.1.0" "pandas>=1.3.4" tensorflow "importlib-resources==5.8.0" "nbformat==5.4.0" "py3Dmol==2.0.1" ipython appdirs jupyterlab ipywidgets pytest
      # - name: Check imports of submodules
      #   run : |
      #     eval "$(conda shell.bash hook)"
      #     conda activate AlphaPulldown
      #     # install PyTorch
      #     pip install torch --pre -f https://download.pytorch.org/whl/nightly/cu121/torch_nightly.html
      #     # setup unicore
      #     git clone https://github.com/dptech-corp/Uni-Core.git
      #     cd Uni-Core
      #     python3 setup.py install --disable-cuda-ext
      #     python -c "from unifold.alphalink_inference import alphalink_prediction"
      #     python -c "from unifold.dataset import process_ap"
      #     python -c "from unifold.config import model_config"
      #     python -c "from colabfold.batch import get_queries, unserialize_msa, get_msa_and_templates, msa_to_str, build_monomer_feature, parse_fasta"
      - name: Run Tests
        run: |
          eval "$(conda shell.bash hook)"
          conda activate AlphaPulldown
          python3 -m unittest discover -s test -t test
      - name: Build and Push to Docker Hub
        uses: mr-smithers-excellent/docker-build-push@v5
        with:
          image: dmolodenskiy/AlphaPulldown
          registry: docker.io
          dockerfile: docker/Dockerfile
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_PASSWORD }}
          tags: latest
          pytest -s test/test_custom_db.py
          pytest -s test/test_remove_clashes_low_plddt.py
          # python3 -m unittest test/test_crosslink_input.py

4 changes: 4 additions & 0 deletions .gitmodules
@@ -10,3 +10,7 @@
    path = alphapulldown/analysis_pipeline/af2plots
    url = https://gitlab.com/gchojnowski/af2plots.git
    branch = main
[submodule "unifold"]
    path = unifold
    url = https://github.com/dingquanyu/Uni-Fold.git
    branch = main
2 changes: 1 addition & 1 deletion alphafold
31 changes: 31 additions & 0 deletions alphapulldown/plot_pae.py
@@ -57,3 +57,34 @@ def plot_pae(seqs: list, order, feature_dir, job_name):
            ax1.axvline(t, color="black", linewidth=3.5)
        plt.title("ranked_{}".format(i))
        plt.savefig(f"{feature_dir}/{job_name}_PAE_plot_ranked_{i}.png")

def plot_pae_from_matrix(seqs,pae_matrix,figure_name=''):
    xticks = []
    initial_tick = 0
    for s in seqs:
        initial_tick = initial_tick + len(s)
        xticks.append(initial_tick)

    xticks_labels = []
    for i, t in enumerate(xticks):
        xticks_labels.append(str(i + 1))

    yticks_labels = []
    for s in seqs:
        yticks_labels.append(str(len(s)))
    fig, ax1 = plt.subplots(1, 1)
    # plt.figure(figsize=(3,18))
    check = pae_matrix
    # NOTE: this second subplots call replaces the figure created above, which stays unused
    fig, ax1 = plt.subplots(1, 1)
    pos = ax1.imshow(check, cmap="bwr", vmin=0, vmax=30)
    ax1.set_xticks(xticks)
    ax1.set_yticks(xticks)

    ax1.set_xticklabels(xticks_labels, size="large")
    ax1.set_yticklabels(yticks_labels,size="large")
    fig.colorbar(pos).ax.set_title("unit: Angstrom")
    for t in xticks:
        ax1.axhline(t, color="black", linewidth=3.5)
        ax1.axvline(t, color="black", linewidth=3.5)
    # NOTE: i is the last index left over from the enumerate loop above
    plt.title("ranked_{}".format(i))
    plt.savefig(figure_name)
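The tick positions in `plot_pae_from_matrix` are running sums of the chain lengths, so each tick marks where one chain ends in the concatenated PAE matrix. A minimal sketch of just that bookkeeping, without matplotlib, might look like the following; the helper names `chain_boundaries` and `tick_labels` are hypothetical and not part of this PR.

```python
from itertools import accumulate


def chain_boundaries(seqs):
    """Return the cumulative end index of each chain in the concatenated sequence."""
    return list(accumulate(len(s) for s in seqs))


def tick_labels(seqs):
    """Mirror the labels built in the diff: chain numbers on x, chain lengths on y."""
    x_labels = [str(i + 1) for i in range(len(seqs))]
    y_labels = [str(len(s)) for s in seqs]
    return x_labels, y_labels


seqs = ["MKV", "AG", "LLPQ"]  # toy sequences, not real test data
print(chain_boundaries(seqs))  # [3, 5, 9]
print(tick_labels(seqs))       # (['1', '2', '3'], ['3', '2', '4'])
```

One thing the sketch makes visible: the x and y axes share the same tick positions but carry different labels (chain index versus chain length). The title line in the diff formats `i`, which at that point appears to be the last index left over from the `enumerate` loop rather than a model rank.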
78 changes: 64 additions & 14 deletions alphapulldown/run_multimer_jobs.py
@@ -66,7 +66,24 @@
flags.DEFINE_integer(
    "msa_depth", None, "Number of sequences to use from the MSA (by default is taken from AF model config)"
)
flags.DEFINE_boolean(
    "use_unifold",False,"Whether unifold models are going to be used. Default is False"
)

flags.DEFINE_boolean(
    "use_alphalink",False,"Whether alphalink models are going to be used. Default is False"
)
flags.DEFINE_string(
    "crosslinks",None,"Path to crosslink information pickle"
)
flags.DEFINE_string(
    "alphalink_weight",None,'Path to AlphaLink neural network weights'
)
flags.DEFINE_string(
    "unifold_param",None,'Path to UniFold neural network weights'
)
flags.DEFINE_enum("unifold_model_name","multimer_af2",
                  ["multimer_af2","multimer_ft","multimer","multimer_af2_v3","multimer_af2_model45_v3"],"choose unifold model structure")
flags.mark_flag_as_required("output_path")

delattr(flags.FLAGS, "models_to_relax")
@@ -305,20 +322,53 @@ def predict_individual_jobs(multimer_object, output_path, model_runners, random_
    if not isinstance(multimer_object, MultimericObject):
        multimer_object.input_seqs = [multimer_object.sequence]


    predict(
        model_runners,
        output_path,
        multimer_object.feature_dict,
        random_seed,
        FLAGS.benchmark,
        fasta_name=multimer_object.description,
        models_to_relax=FLAGS.models_to_relax,
        seqs=multimer_object.input_seqs,
    )
    create_and_save_pae_plots(multimer_object, output_path)


    if FLAGS.use_unifold:
        from unifold.inference import config_args,unifold_config_model,unifold_predict
        from unifold.dataset import process_ap
        from unifold.config import model_config
        configs = model_config(FLAGS.unifold_model_name)
        general_args = config_args(FLAGS.unifold_param,
                                   target_name=multimer_object.description,
                                   output_dir=output_path)
        model_runner = unifold_config_model(general_args)
        # First need to add num_recycling_iters to the feature dictionary
        # multimer_object.feature_dict.update({"num_recycling_iters":general_args.max_recycling_iters})
        processed_features,_ = process_ap(config=configs.data,
                                          features=multimer_object.feature_dict,
                                          mode="predict",labels=None,
                                          seed=42,batch_idx=None,
                                          data_idx=None,is_distillation=False
                                          )
        logging.info(f"finished configuring the Unifold AlphaFold model and process numpy features")
        unifold_predict(model_runner,general_args,processed_features)

    elif FLAGS.use_alphalink:
        assert FLAGS.crosslinks is not None
        assert FLAGS.alphalink_weight is not None
        from unifold.alphalink_inference import alphalink_prediction

        from unifold.config import model_config
        logging.info(f"Start using AlphaLink weights and cross-link information")
        MODEL_NAME = 'model_5_ptm_af2'
        configs = model_config(MODEL_NAME)
        alphalink_prediction(multimer_object.feature_dict,
                             os.path.join(FLAGS.output_path,multimer_object.description),
                             input_seqs = multimer_object.input_seqs,
                             param_path = FLAGS.alphalink_weight,
                             configs = configs,crosslinks=FLAGS.crosslinks,
                             chain_id_map=multimer_object.chain_id_map)
    else:
        predict(
            model_runners,
            output_path,
            multimer_object.feature_dict,
            random_seed,
            FLAGS.benchmark,
            fasta_name=multimer_object.description,
            models_to_relax=FLAGS.models_to_relax,
            seqs=multimer_object.input_seqs,
        )
        create_and_save_pae_plots(multimer_object, output_path)


def predict_multimers(multimers):
Expand Down
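The new branch in `predict_individual_jobs` picks a backend with an `if`/`elif`/`else` chain: `--use_unifold` wins if both flags are set, and the AlphaLink path guards its two required paths with `assert`. A possible way to express the same decision as a small pure function is sketched below. The function name `select_backend` and the returned labels are hypothetical, and the sketch deliberately swaps the asserts for a `ValueError`, since assert statements are skipped when Python runs with `-O`.

```python
def select_backend(use_unifold=False, use_alphalink=False,
                   crosslinks=None, alphalink_weight=None):
    """Pick a prediction backend in the same order as the if/elif/else in the diff."""
    if use_unifold:
        # Checked first, so it takes precedence when both flags are set.
        return "unifold"
    if use_alphalink:
        required = (("crosslinks", crosslinks), ("alphalink_weight", alphalink_weight))
        missing = [name for name, value in required if value is None]
        if missing:
            raise ValueError(
                "--use_alphalink also needs: " + ", ".join("--" + m for m in missing)
            )
        return "alphalink"
    # Neither flag set: fall back to the standard AlphaFold predict() path.
    return "alphafold"


print(select_backend())
print(select_backend(use_alphalink=True, crosslinks="xl.pkl.gz", alphalink_weight="w.pt"))
```

Keeping the decision in one function would also make it straightforward to unit-test the flag combinations without loading any model weights.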
65 changes: 65 additions & 0 deletions test/test_crosslink_inference.py
@@ -0,0 +1,65 @@
import shutil
import tempfile
import unittest
import sys
import os
import torch
from unifold.modules.alphafold import AlphaFold
from unifold.alphalink_inference import prepare_model_runner
from unifold.alphalink_inference import alphalink_prediction
from unifold.dataset import process_ap
from unifold.config import model_config
from alphapulldown.utils import create
from alphapulldown.run_multimer_jobs import predict_individual_jobs,create_custom_jobs

class _TestBase(unittest.TestCase):
    def setUp(self) -> None:
        self.crosslink_file_path = os.path.join(os.path.dirname(__file__),"test_data/example_crosslink.pkl.gz")
        self.config_data_model_name = 'model_5_ptm_af2'
        self.config_alphafold_model_name = 'multimer_af2_crop'

class TestCrosslinkInference(_TestBase):
    def setUp(self) -> None:
        super().setUp()
        self.output_dir = tempfile.mkdtemp()
        self.monomer_object_path = os.path.join(os.path.dirname(__file__),"test_data/")
        self.protein_list = os.path.join(os.path.dirname(__file__),"test_data/example_crosslinked_pair.txt")
        self.alphalink_weight = '/g/alphafold/alphalink_weights/AlphaLink-Multimer_SDA_v3.pt'
        self.multimerobject = create_custom_jobs(self.protein_list,self.monomer_object_path,job_index=1,pair_msa=True)[0]

    def test1_process_features(self):
        """Test whether the PyTorch model of AlphaLink can be initiated successfully"""
        configs = model_config(self.config_data_model_name)
        processed_features,_ = process_ap(config=configs.data,
                                          features=self.multimerobject.feature_dict,
                                          mode="predict",labels=None,
                                          seed=42,batch_idx=None,
                                          data_idx=None,is_distillation=False,
                                          chain_id_map = self.multimerobject.chain_id_map,
                                          crosslinks = self.crosslink_file_path
                                          )

    def test2_load_AlphaLink_weights(self):
        """This is testing whether loading the PyTorch checkpoint is successful"""
        if torch.cuda.is_available():
            model_device = 'cuda:0'
        else:
            model_device = 'cpu'

        config = model_config(self.config_alphafold_model_name)
        model = AlphaFold(config)
        state_dict = torch.load(self.alphalink_weight)["ema"]["params"]
        state_dict = {".".join(k.split(".")[1:]): v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        model.to(model_device)

    def test3_test_inference(self):
        if torch.cuda.is_available():
            model_device = 'cuda:0'
        else:
            model_device = 'cpu'
        model = prepare_model_runner(self.alphalink_weight,model_device=model_device)


if __name__ == '__main__':
    unittest.main()
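In `test2_load_AlphaLink_weights`, the dict comprehension drops the first dot-separated component of every checkpoint key before calling `load_state_dict`. The effect can be seen on a toy dictionary, as in the sketch below. The `model.` prefix and the key names are assumptions for illustration only (the diff does not show the actual checkpoint keys), and the helper name `strip_first_component` is hypothetical.

```python
def strip_first_component(state_dict):
    """Drop the first dot-separated component of every key, as in test2 of the diff."""
    return {".".join(key.split(".")[1:]): value for key, value in state_dict.items()}


# Toy stand-in for checkpoint["ema"]["params"]; real checkpoints map names to tensors.
checkpoint_params = {
    "model.evoformer.linear.weight": 1,
    "model.structure_module.bias": 2,
}

print(strip_first_component(checkpoint_params))
# {'evoformer.linear.weight': 1, 'structure_module.bias': 2}
```

A caveat of this approach: a key with no dot at all is mapped to the empty string, so several such keys would collide silently.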
49 changes: 49 additions & 0 deletions test/test_crosslink_input.py
@@ -0,0 +1,49 @@
import unittest
from unifold.dataset import calculate_offsets,create_xl_features,bin_xl
from alphafold.data.pipeline_multimer import _FastaChain
import numpy as np
import gzip,pickle
import torch
class TestCreateObjects(unittest.TestCase):
    def setUp(self) -> None:
        self.crosslink_info ="./test/test_data/test_xl_input.pkl.gz"
        self.asym_id = [1]*10 + [2]*25 + [3]*40
        self.chain_id_map = {
            "A":_FastaChain(sequence='',description='chain1'),
            "B":_FastaChain(sequence='',description='chain2'),
            "C":_FastaChain(sequence='',description='chain3')
        }
        self.bins = torch.arange(0,1.05,0.05)
        return super().setUp()

    def test1_calculate_offsets(self):
        offsets = calculate_offsets(self.asym_id)
        offsets = offsets.tolist()
        expected_offsets = [0,10,35,75]
        self.assertEqual(offsets,expected_offsets)

    def test2_create_xl_inputs(self):
        offsets = calculate_offsets(self.asym_id)
        xl_pickle = pickle.load(gzip.open(self.crosslink_info,'rb'))
        xl = create_xl_features(xl_pickle,offsets,chain_id_map = self.chain_id_map)
        expected_xl = torch.tensor([[10,35,0.01],
                                    [3,27,0.01],
                                    [5,56,0.01],
                                    [20,65,0.01]])
        self.assertTrue(torch.equal(xl,expected_xl))

    def test3_bin_xl(self):
        offsets = calculate_offsets(self.asym_id)
        xl_pickle = pickle.load(gzip.open(self.crosslink_info,'rb'))
        xl = create_xl_features(xl_pickle,offsets,chain_id_map = self.chain_id_map)
        num_res = len(self.asym_id)
        xl = bin_xl(xl,num_res)
        expected_xl = np.zeros((num_res,num_res,1))
        expected_xl[3,27,0] = expected_xl[27,3,0] = torch.bucketize(0.99,self.bins)
        expected_xl[10,35,0] = expected_xl[35,10,0] = torch.bucketize(0.99,self.bins)
        expected_xl[5,56,0] = expected_xl[56,5,0] = torch.bucketize(0.99,self.bins)
        expected_xl[20,65,0] = expected_xl[65,20,0] = torch.bucketize(0.99,self.bins)
        self.assertTrue(np.array_equal(xl,expected_xl))

if __name__ == "__main__":
    unittest.main()
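The expectations in these tests can be reproduced with plain Python, which may help when reading them. The sketch below is a stand-in under stated assumptions, not the `unifold.dataset` implementation: `offsets_from_asym_id` and `bucket_index` are hypothetical names, the offsets are assumed to be cumulative chain lengths (which matches `[0,10,35,75]`), and the bin edges are assumed to be the 21 values 0.0, 0.05, ..., 1.0. `bisect_left` returns the index of the first edge that is greater than or equal to the value, which is what `torch.bucketize` documents for its default `right=False`.

```python
from bisect import bisect_left
from itertools import groupby


def offsets_from_asym_id(asym_id):
    """Cumulative chain start positions, with the total length as the last entry."""
    offsets = [0]
    for _, run in groupby(asym_id):
        # Each run of identical ids is one chain; add its length to the running total.
        offsets.append(offsets[-1] + sum(1 for _ in run))
    return offsets


def bucket_index(value, boundaries):
    """Index of the first boundary that is >= value."""
    return bisect_left(boundaries, value)


asym_id = [1] * 10 + [2] * 25 + [3] * 40
boundaries = [i * 0.05 for i in range(21)]  # assumed edges: 0.0, 0.05, ..., 1.0

print(offsets_from_asym_id(asym_id))   # [0, 10, 35, 75]
print(bucket_index(0.99, boundaries))  # 20
```

Note that `test2_create_xl_inputs` expects 0.01 in the third column while `test3_bin_xl` bins 0.99, which suggests the binned quantity is 1 minus that value; the diff does not show `bin_xl` itself, so this is only an inference.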
Binary file added test/test_data/example_crosslink.pkl.gz
Binary file not shown.
Binary file added test/test_data/test_xl_input.pkl.gz
Binary file not shown.
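Both fixtures are gzip-compressed pickles, and the tests read them with `pickle.load(gzip.open(...))`, which leaves the file handle to be closed by the garbage collector. A round trip using context managers could look like the sketch below. The payload layout is purely hypothetical, because the real fixture contents are not visible in the diff.

```python
import gzip
import os
import pickle
import tempfile

# Hypothetical payload: the real crosslink fixture layout is not shown in the diff.
payload = {"chain1": {"chain2": [(3, 17, 0.01)]}}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "example_xl.pkl.gz")

    # Write: gzip.open in binary mode gives a file object that pickle can dump to.
    with gzip.open(path, "wb") as handle:
        pickle.dump(payload, handle)

    # Read: the handle is closed as soon as the with-block exits.
    with gzip.open(path, "rb") as handle:
        restored = pickle.load(handle)

print(restored == payload)  # True
```

As with any pickle, such files should only be loaded when they come from a trusted source, since unpickling can execute arbitrary code.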
1 change: 1 addition & 0 deletions unifold
Submodule unifold added at 726788