-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #200 from KosinskiLab/add-unifold
Integrated AlphaLink2 inference with AlphaPulldown; Added PAE plotting to AlphaLink2 inference
- Loading branch information
Showing
10 changed files
with
247 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import shutil | ||
import tempfile | ||
import unittest | ||
import sys | ||
import os | ||
import torch | ||
from unifold.modules.alphafold import AlphaFold | ||
from unifold.alphalink_inference import prepare_model_runner | ||
from unifold.alphalink_inference import alphalink_prediction | ||
from unifold.dataset import process_ap | ||
from unifold.config import model_config | ||
from alphapulldown.utils import create | ||
from alphapulldown.run_multimer_jobs import predict_individual_jobs,create_custom_jobs | ||
|
||
class _TestBase(unittest.TestCase): | ||
def setUp(self) -> None: | ||
self.crosslink_file_path = os.path.join(os.path.dirname(__file__),"test_data/example_crosslink.pkl.gz") | ||
self.config_data_model_name = 'model_5_ptm_af2' | ||
self.config_alphafold_model_name = 'multimer_af2_crop' | ||
|
||
class TestCrosslinkInference(_TestBase): | ||
def setUp(self) -> None: | ||
super().setUp() | ||
self.output_dir = tempfile.mkdtemp() | ||
self.monomer_object_path = os.path.join(os.path.dirname(__file__),"test_data/") | ||
self.protein_list = os.path.join(os.path.dirname(__file__),"test_data/example_crosslinked_pair.txt") | ||
self.alphalink_weight = '/g/alphafold/alphalink_weights/AlphaLink-Multimer_SDA_v3.pt' | ||
self.multimerobject = create_custom_jobs(self.protein_list,self.monomer_object_path,job_index=1,pair_msa=True)[0] | ||
|
||
def test1_process_features(self): | ||
"""Test whether the PyTorch model of AlphaLink can be initiated successfully""" | ||
configs = model_config(self.config_data_model_name) | ||
processed_features,_ = process_ap(config=configs.data, | ||
features=self.multimerobject.feature_dict, | ||
mode="predict",labels=None, | ||
seed=42,batch_idx=None, | ||
data_idx=None,is_distillation=False, | ||
chain_id_map = self.multimerobject.chain_id_map, | ||
crosslinks = self.crosslink_file_path | ||
) | ||
|
||
def test2_load_AlphaLink_weights(self): | ||
"""This is testing weither loading the PyTorch checkpoint is sucessfull""" | ||
if torch.cuda.is_available(): | ||
model_device = 'cuda:0' | ||
else: | ||
model_device = 'cpu' | ||
|
||
config = model_config(self.config_alphafold_model_name) | ||
model = AlphaFold(config) | ||
state_dict = torch.load(self.alphalink_weight)["ema"]["params"] | ||
state_dict = {".".join(k.split(".")[1:]): v for k, v in state_dict.items()} | ||
model.load_state_dict(state_dict) | ||
model.to(model_device) | ||
|
||
def test3_test_inference(self): | ||
if torch.cuda.is_available(): | ||
model_device = 'cuda:0' | ||
else: | ||
model_device = 'cpu' | ||
model = prepare_model_runner(self.alphalink_weight,model_device=model_device) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import unittest | ||
from unifold.dataset import calculate_offsets,create_xl_features,bin_xl | ||
from alphafold.data.pipeline_multimer import _FastaChain | ||
import numpy as np | ||
import gzip,pickle | ||
import torch | ||
class TestCreateObjects(unittest.TestCase): | ||
def setUp(self) -> None: | ||
self.crosslink_info ="./test/test_data/test_xl_input.pkl.gz" | ||
self.asym_id = [1]*10 + [2]*25 + [3]*40 | ||
self.chain_id_map = { | ||
"A":_FastaChain(sequence='',description='chain1'), | ||
"B":_FastaChain(sequence='',description='chain2'), | ||
"C":_FastaChain(sequence='',description='chain3') | ||
} | ||
self.bins = torch.arange(0,1.05,0.05) | ||
return super().setUp() | ||
|
||
def test1_calculate_offsets(self): | ||
offsets = calculate_offsets(self.asym_id) | ||
offsets = offsets.tolist() | ||
expected_offsets = [0,10,35,75] | ||
self.assertEqual(offsets,expected_offsets) | ||
|
||
def test2_create_xl_inputs(self): | ||
offsets = calculate_offsets(self.asym_id) | ||
xl_pickle = pickle.load(gzip.open(self.crosslink_info,'rb')) | ||
xl = create_xl_features(xl_pickle,offsets,chain_id_map = self.chain_id_map) | ||
expected_xl = torch.tensor([[10,35,0.01], | ||
[3,27,0.01], | ||
[5,56,0.01], | ||
[20,65,0.01]]) | ||
self.assertTrue(torch.equal(xl,expected_xl)) | ||
|
||
def test3_bin_xl(self): | ||
offsets = calculate_offsets(self.asym_id) | ||
xl_pickle = pickle.load(gzip.open(self.crosslink_info,'rb')) | ||
xl = create_xl_features(xl_pickle,offsets,chain_id_map = self.chain_id_map) | ||
num_res = len(self.asym_id) | ||
xl = bin_xl(xl,num_res) | ||
expected_xl = np.zeros((num_res,num_res,1)) | ||
expected_xl[3,27,0] = expected_xl[27,3,0] = torch.bucketize(0.99,self.bins) | ||
expected_xl[10,35,0] = expected_xl[35,10,0] = torch.bucketize(0.99,self.bins) | ||
expected_xl[5,56,0] = expected_xl[56,5,0] = torch.bucketize(0.99,self.bins) | ||
expected_xl[20,65,0] = expected_xl[65,20,0] = torch.bucketize(0.99,self.bins) | ||
self.assertTrue(np.array_equal(xl,expected_xl)) | ||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Binary file not shown.
Binary file not shown.