diff --git a/.github/workflows/github_actions.yml b/.github/workflows/github_actions.yml index 24a46cf4..9288dd30 100644 --- a/.github/workflows/github_actions.yml +++ b/.github/workflows/github_actions.yml @@ -35,11 +35,21 @@ jobs: conda develop alphapulldown conda develop alphapulldown/ColabFold conda develop alphafold - #conda info - #conda list - #python -c "import sys; print(sys.path)" - #echo "Python path: $PYTHONPATH" python -c "import alphafold; import os; print('Alphafold module is located at:', alphafold.__file__); alphafold_dir = os.path.dirname(alphafold.__file__); print('Contents of the Alphafold directory:', os.listdir(alphafold_dir))" + - name: Install dependencies in unifold setup.py + run: | + eval "$(conda shell.bash hook)" + export WORKDIRPATH=$PWD + conda activate AlphaPulldown + cd $WORKDIRPATH/unifold && python3 setup.py install + - name: Install dependencies in AlphaFold setup.py + run: | + eval "$(conda shell.bash hook)" + export WORKDIRPATH=$PWD + conda activate AlphaPulldown + pip install protobuf==4.23.0 + cd $WORKDIRPATH/alphafold && python3 setup.py install + python -c "import tree" - name: Install Dependencies run: | eval "$(conda shell.bash hook)" @@ -51,17 +61,25 @@ jobs: pip install torch --pre -f https://download.pytorch.org/whl/nightly/cu121/torch_nightly.html pip install jax==0.3.25 jaxlib==0.3.25+cuda11.cudnn805 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html pip install "absl-py>=0.13.0" dm-haiku "dm-tree>=0.1.6" "h5py>=3.1.0" "ml-collections>=0.1.0" "pandas>=1.3.4" tensorflow "importlib-resources==5.8.0" "nbformat==5.4.0" "py3Dmol==2.0.1" ipython appdirs jupyterlab ipywidgets pytest + # - name: Check imports of submodules + # run : | + # eval "$(conda shell.bash hook)" + # conda activate AlphaPulldown + # # install PyTorch + # pip install torch --pre -f https://download.pytorch.org/whl/nightly/cu121/torch_nightly.html + # # setup unicore + # git clone https://github.com/dptech-corp/Uni-Core.git + # cd 
def plot_pae_from_matrix(seqs, pae_matrix, figure_name='', title=None):
    """Render a PAE matrix as a heatmap with chain boundaries and save it.

    Args:
        seqs: Iterable of chain sequences. Only ``len()`` of each entry is
            used, to place the chain-boundary ticks and separator lines.
        pae_matrix: 2-D array-like handed directly to ``Axes.imshow``.
        figure_name: Destination path handed to ``Figure.savefig``.
        title: Optional plot title. The previous implementation always wrote
            ``"ranked_<k>"`` where ``k`` was a leftover loop index unrelated
            to any ranking (and raised ``NameError`` for empty ``seqs``), so
            a title is now drawn only when the caller supplies one.

    Returns:
        None. The figure is written to ``figure_name`` and then closed.
    """
    # Cumulative chain lengths: the index at which each chain ends.
    boundaries = []
    running_total = 0
    for seq in seqs:
        running_total += len(seq)
        boundaries.append(running_total)

    # x labels are 1-based chain numbers, y labels are the chain lengths.
    xticks_labels = [str(chain_no) for chain_no in range(1, len(boundaries) + 1)]
    yticks_labels = [str(len(seq)) for seq in seqs]

    # Create exactly one figure; the original created two and leaked the first.
    fig, ax1 = plt.subplots(1, 1)
    try:
        pos = ax1.imshow(pae_matrix, cmap="bwr", vmin=0, vmax=30)
        ax1.set_xticks(boundaries)
        ax1.set_yticks(boundaries)

        ax1.set_xticklabels(xticks_labels, size="large")
        ax1.set_yticklabels(yticks_labels, size="large")
        fig.colorbar(pos).ax.set_title("unit: Angstrom")

        # Thick black lines separate the per-chain blocks of the matrix.
        for boundary in boundaries:
            ax1.axhline(boundary, color="black", linewidth=3.5)
            ax1.axvline(boundary, color="black", linewidth=3.5)

        if title is not None:
            ax1.set_title(title)

        # Save through the figure object rather than pyplot's "current figure".
        fig.savefig(figure_name)
    finally:
        # Always release the figure so repeated calls do not accumulate memory.
        plt.close(fig)
Default it False" +) +flags.DEFINE_string( + "crosslinks",None,"Path to crosslink information pickle" +) +flags.DEFINE_string( + "alphalink_weight",None,'Path to AlphaLink neural network weights' +) +flags.DEFINE_string( + "unifold_param",None,'Path to UniFold neural network weights' +) +flags.DEFINE_enum("unifold_model_name","multimer_af2", + ["multimer_af2","multimer_ft","multimer","multimer_af2_v3","multimer_af2_model45_v3"],"choose unifold model structure") flags.mark_flag_as_required("output_path") delattr(flags.FLAGS, "models_to_relax") @@ -305,20 +322,53 @@ def predict_individual_jobs(multimer_object, output_path, model_runners, random_ if not isinstance(multimer_object, MultimericObject): multimer_object.input_seqs = [multimer_object.sequence] - - predict( - model_runners, - output_path, - multimer_object.feature_dict, - random_seed, - FLAGS.benchmark, - fasta_name=multimer_object.description, - models_to_relax=FLAGS.models_to_relax, - seqs=multimer_object.input_seqs, - ) - create_and_save_pae_plots(multimer_object, output_path) - - + if FLAGS.use_unifold: + from unifold.inference import config_args,unifold_config_model,unifold_predict + from unifold.dataset import process_ap + from unifold.config import model_config + configs = model_config(FLAGS.unifold_model_name) + general_args = config_args(FLAGS.unifold_param, + target_name=multimer_object.description, + output_dir=output_path) + model_runner = unifold_config_model(general_args) + # First need to add num_recycling_iters to the feature dictionary + # multimer_object.feature_dict.update({"num_recycling_iters":general_args.max_recycling_iters}) + processed_features,_ = process_ap(config=configs.data, + features=multimer_object.feature_dict, + mode="predict",labels=None, + seed=42,batch_idx=None, + data_idx=None,is_distillation=False + ) + logging.info(f"finished configuring the Unifold AlphlaFold model and process numpy features") + unifold_predict(model_runner,general_args,processed_features) + + elif 
class _TestBase(unittest.TestCase):
    """Shared fixture values for the AlphaLink crosslink inference tests."""

    def setUp(self) -> None:
        super().setUp()
        self.crosslink_file_path = os.path.join(os.path.dirname(__file__),"test_data/example_crosslink.pkl.gz")
        self.config_data_model_name = 'model_5_ptm_af2'
        self.config_alphafold_model_name = 'multimer_af2_crop'


class TestCrosslinkInference(_TestBase):
    """Checks feature processing, checkpoint loading and model-runner setup."""

    def setUp(self) -> None:
        super().setUp()
        self.output_dir = tempfile.mkdtemp()
        # Register cleanup immediately so the directory is removed even when a
        # later step of setUp raises (tearDown would not run in that case).
        self.addCleanup(shutil.rmtree, self.output_dir, ignore_errors=True)
        self.monomer_object_path = os.path.join(os.path.dirname(__file__),"test_data/")
        self.protein_list = os.path.join(os.path.dirname(__file__),"test_data/example_crosslinked_pair.txt")
        self.alphalink_weight = '/g/alphafold/alphalink_weights/AlphaLink-Multimer_SDA_v3.pt'
        self.multimerobject = create_custom_jobs(self.protein_list,self.monomer_object_path,job_index=1,pair_msa=True)[0]

    @staticmethod
    def _model_device() -> str:
        """Return the first CUDA device when one is available, else the CPU."""
        return 'cuda:0' if torch.cuda.is_available() else 'cpu'

    def _require_weights(self) -> None:
        """Skip the calling test when the AlphaLink checkpoint is not present."""
        if not os.path.exists(self.alphalink_weight):
            self.skipTest(f"AlphaLink weights not found at {self.alphalink_weight}")

    def test1_process_features(self):
        """Test that the multimer features can be processed together with crosslinks."""
        configs = model_config(self.config_data_model_name)
        processed_features,_ = process_ap(config=configs.data,
                                          features=self.multimerobject.feature_dict,
                                          mode="predict",labels=None,
                                          seed=42,batch_idx=None,
                                          data_idx=None,is_distillation=False,
                                          chain_id_map = self.multimerobject.chain_id_map,
                                          crosslinks = self.crosslink_file_path
                                          )
        self.assertIsNotNone(processed_features)

    def test2_load_AlphaLink_weights(self):
        """Test that the PyTorch checkpoint loads into the AlphaFold module."""
        self._require_weights()
        model_device = self._model_device()

        config = model_config(self.config_alphafold_model_name)
        model = AlphaFold(config)
        # map_location lets a GPU-saved checkpoint be loaded on CPU-only hosts.
        state_dict = torch.load(self.alphalink_weight, map_location=model_device)["ema"]["params"]
        # Drop the first dotted component of every parameter name.
        state_dict = {".".join(k.split(".")[1:]): v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        model.to(model_device)

    def test3_test_inference(self):
        """Test that a model runner can be prepared from the AlphaLink checkpoint."""
        self._require_weights()
        model_device = self._model_device()
        model = prepare_model_runner(self.alphalink_weight,model_device=model_device)
class TestCreateObjects(unittest.TestCase):
    """Unit tests for the crosslink helpers imported from ``unifold.dataset``."""

    def setUp(self) -> None:
        # Local import: this module's top-level imports do not include ``os``.
        import os

        super().setUp()
        # Resolve the fixture relative to this file so the tests do not depend
        # on the directory they are launched from.
        self.crosslink_info = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "test_data",
            "test_xl_input.pkl.gz",
        )
        # Three chains of 10, 25 and 40 residues.
        self.asym_id = [1]*10 + [2]*25 + [3]*40
        self.chain_id_map = {
            "A":_FastaChain(sequence='',description='chain1'),
            "B":_FastaChain(sequence='',description='chain2'),
            "C":_FastaChain(sequence='',description='chain3')
        }
        self.bins = torch.arange(0,1.05,0.05)

    def _load_xl_features(self):
        """Build the crosslink features from the gzipped pickle fixture."""
        offsets = calculate_offsets(self.asym_id)
        # The context manager closes the gzip handle; the original leaked it.
        # NOTE: pickle is only acceptable here because the file is a fixture
        # shipped with the repository, not external input.
        with gzip.open(self.crosslink_info,'rb') as handle:
            xl_pickle = pickle.load(handle)
        return create_xl_features(xl_pickle,offsets,chain_id_map = self.chain_id_map)

    def test1_calculate_offsets(self):
        """Offsets are the cumulative chain lengths, starting at zero."""
        offsets = calculate_offsets(self.asym_id)
        offsets = offsets.tolist()
        expected_offsets = [0,10,35,75]
        self.assertEqual(offsets,expected_offsets)

    def test2_create_xl_inputs(self):
        """The fixture yields the expected (residue, residue, value) rows."""
        xl = self._load_xl_features()
        expected_xl = torch.tensor([[10,35,0.01],
                                    [3,27,0.01],
                                    [5,56,0.01],
                                    [20,65,0.01]])
        self.assertTrue(torch.equal(xl,expected_xl))

    def test3_bin_xl(self):
        """Binned crosslinks are written symmetrically for every residue pair."""
        xl = self._load_xl_features()
        num_res = len(self.asym_id)
        xl = bin_xl(xl,num_res)
        expected_xl = np.zeros((num_res,num_res,1))
        # Every crosslink in the fixture is expected in the bin that holds 0.99.
        expected_bin = torch.bucketize(0.99,self.bins)
        for res_i, res_j in ((3,27),(10,35),(5,56),(20,65)):
            expected_xl[res_i,res_j,0] = expected_xl[res_j,res_i,0] = expected_bin
        self.assertTrue(np.array_equal(xl,expected_xl))
/dev/null and b/test/test_data/example_crosslink.pkl.gz differ diff --git a/test/test_data/test_xl_input.pkl.gz b/test/test_data/test_xl_input.pkl.gz new file mode 100644 index 00000000..7ab8d296 Binary files /dev/null and b/test/test_data/test_xl_input.pkl.gz differ diff --git a/unifold b/unifold new file mode 160000 index 00000000..7267885a --- /dev/null +++ b/unifold @@ -0,0 +1 @@ +Subproject commit 7267885a078d5878e4d4dfc25d42ef10af678354