-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy path score_design.py
140 lines (105 loc) · 4.36 KB
/
score_design.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import argparse
import os
import sys
import tempfile
from glob import glob

import numpy as np
import torch
from Bio import SeqIO
from tqdm import tqdm

import deepab
from deepab.models.AbResNet import load_model
from deepab.models.ModelEnsemble import ModelEnsemble
from deepab.analysis.design_metrics import *
from deepab.util.pdb import cdr_indices, pdb2fasta, renumber_pdb, write_pdb_bfactor
# CDR loop labels ("h" = heavy chain, "l" = light chain). Not referenced by
# the code shown in this script.
cdr_names = "h1 h2 h3 l1 l2 l3".split()
# Names of model output branches; not referenced by the code shown in this
# script (presumably kept for callers that import it).
branch_names = "ca cb no omega theta phi".split()
def get_sequence_pairs(h_fasta_file, l_fasta_file):
    """Pair heavy and light chain sequences by FASTA record ID.

    Args:
        h_fasta_file: Path to a FASTA file of Fv heavy chain sequences.
        l_fasta_file: Path to a FASTA file of Fv light chain sequences.

    Returns:
        Dict mapping each record ID to ``{"H": heavy_seq, "L": light_seq}``.

    If any ID appears in only one of the two files, every such ID is
    reported on stdout and the process exits via ``sys.exit`` with a
    message.
    """
    sequence_pairs = {}
    with open(h_fasta_file) as f:
        for record in SeqIO.parse(f, "fasta"):
            sequence_pairs.setdefault(record.id, {})["H"] = str(record.seq)
    with open(l_fasta_file) as f:
        for record in SeqIO.parse(f, "fasta"):
            # setdefault so a light-only ID is collected and reported below
            # instead of raising KeyError here.
            sequence_pairs.setdefault(record.id, {})["L"] = str(record.seq)

    has_mismatch_seq = False
    for seq_id, seq_dict in sequence_pairs.items():
        if "H" not in seq_dict:
            print("Found light seq but not heavy seq for ID {}".format(seq_id))
            has_mismatch_seq = True
        if "L" not in seq_dict:
            print("Found heavy seq but not light seq for ID {}".format(seq_id))
            has_mismatch_seq = True
    if has_mismatch_seq:
        sys.exit("Found mismatched sequences. Exiting.")
    return sequence_pairs
def score_designs(model, wt_fasta, h_fasta, l_fasta, device):
    """Score every designed heavy/light pair relative to the wild type.

    Each design is written to a temporary two-record FASTA (``>:H`` then
    ``>:L``), scored with ``get_fasta_cce``, and the difference from the
    wild-type score is stored on the pair under ``"dCCE"``.

    Args:
        model: Model passed through to ``get_fasta_cce``.
        wt_fasta: Path to the FASTA file holding the wild-type Fv sequences.
        h_fasta: Path to the FASTA file of designed heavy chain sequences.
        l_fasta: Path to the FASTA file of designed light chain sequences.
        device: Device passed through to ``get_fasta_cce``.

    Returns:
        The dict produced by ``get_sequence_pairs`` with a ``"dCCE"`` entry
        (design score minus wild-type score) added to each value.
    """
    sequence_pairs = get_sequence_pairs(h_fasta, l_fasta)
    wt_cce = get_fasta_cce(model, wt_fasta, device)

    # A private temp directory is removed when the block exits, even if
    # scoring raises. Taking .name from a NamedTemporaryFile that was then
    # discarded deleted the file and re-created it by name, leaving one
    # stray file per design behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        for i, seq_pair in enumerate(
                tqdm(sequence_pairs.values(), total=len(sequence_pairs))):
            # Unique file per design, in case get_fasta_cce keys anything
            # on the path.
            temp_fasta = os.path.join(tmp_dir, "design_{}.fasta".format(i))
            with open(temp_fasta, "w") as f:
                f.write(">:H\n{}\n>:L\n{}\n".format(seq_pair["H"],
                                                    seq_pair["L"]))
            des_cce = get_fasta_cce(model, temp_fasta, device)
            seq_pair["dCCE"] = des_cce - wt_cce
    return sequence_pairs
def _get_args():
"""Gets command line arguments"""
project_path = os.path.abspath(os.path.join(deepab.__file__, "../.."))
desc = ("""
Script for calculating design metrics for antibody Fv sequences.
""")
parser = argparse.ArgumentParser(description=desc)
parser.add_argument(
"wt_fasta_file",
type=str,
help=
"Fasta file containing wild type Fv heavy and light chain sequences.")
parser.add_argument("heavy_fasta_file",
type=str,
help="Fasta file containing Fv heavy chain sequences.")
parser.add_argument("light_fasta_file",
type=str,
help="Fasta file containing Fv light chain sequences.")
parser.add_argument("out_file",
type=str,
default=None,
help="File to save calculated design metrics.")
default_model_dir = "trained_models/ensemble_abresnet"
parser.add_argument(
"--model_dir",
type=str,
default=default_model_dir,
help="Directory containing pretrained model files (in .pt format).")
parser.add_argument('--use_gpu', default=False, action="store_true")
return parser.parse_args()
def _cli():
    """Command-line entry point: score designs and write one row per design.

    Each output line is ``id,heavy_seq,light_seq,dCCE`` with no header row.
    Exits via ``sys.exit`` if no ``*.pt`` model files are found.
    """
    args = _get_args()
    wt_fasta_file = args.wt_fasta_file
    h_fasta_file = args.heavy_fasta_file
    l_fasta_file = args.light_fasta_file
    out_file = args.out_file
    model_dir = args.model_dir

    cuda_available = torch.cuda.is_available()
    if args.use_gpu and not cuda_available:
        # Previously a silent fallback; tell the user GPU was not used.
        print("--use_gpu given but CUDA is not available; using CPU.",
              file=sys.stderr)
    device_type = 'cuda' if cuda_available and args.use_gpu else 'cpu'
    device = torch.device(device_type)

    # glob order depends on the filesystem; sort so the ensemble members
    # are loaded in the same order on every run.
    model_files = sorted(glob(os.path.join(model_dir, "*.pt")))
    if not model_files:
        sys.exit("No model files found at: {}".format(model_dir))
    model = ModelEnsemble(model_files=model_files,
                          load_model=load_model,
                          eval_mode=True,
                          device=device)

    sequence_pairs = score_designs(model,
                                   wt_fasta_file,
                                   h_fasta_file,
                                   l_fasta_file,
                                   device=device)
    # str.format (not the csv module) keeps dCCE formatted exactly as before.
    with open(out_file, "w") as f:
        for seq_id, seq_pair in sequence_pairs.items():
            f.write("{},{},{},{}\n".format(seq_id, seq_pair["H"],
                                           seq_pair["L"], seq_pair["dCCE"]))
if __name__ == "__main__":
    # Only run the CLI when executed as a script, not on import.
    _cli()