Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update make_fixed_size to accommodate numpy arrays #7

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions alphafold/model/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,13 @@ def tf_example_to_features(tf_example: tf.train.Example,

def np_example_to_features(np_example: FeatureDict,
config: ml_collections.ConfigDict,
random_seed: int = 0) -> FeatureDict:
random_seed: int = 0, desired_num_res: int = None) -> FeatureDict:
"""Preprocesses NumPy feature dict using TF pipeline."""
np_example = dict(np_example)
num_res = int(np_example['seq_length'][0])
if desired_num_res is not None:
num_res = desired_num_res
else:
num_res = int(np_example['seq_length'][0])
cfg, feature_names = make_data_config(config, num_res=num_res)

if 'deletion_matrix_int' in np_example:
Expand Down
8 changes: 5 additions & 3 deletions alphafold/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,15 @@ def predict(self,
logging.info('Running predict with shape(feat) = %s',
tree.map_structure(lambda x: x.shape, feat))
result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)

# This block is to ensure benchmark timings are accurate. Some blocking is
# already happening when computing get_confidence_metrics, and this ensures
# all outputs are blocked on.
jax.tree_map(lambda x: x.block_until_ready(), result)
result.update(
get_confidence_metrics(result, multimer_mode=self.multimer_mode))
# result.update(
# get_confidence_metrics(result, multimer_mode=self.multimer_mode))
result.update({"plddt": confidence.compute_plddt(
result['predicted_lddt']['logits'])})

logging.info('Output shape was %s',
tree.map_structure(lambda x: x.shape, result))
return result
8 changes: 4 additions & 4 deletions alphafold/model/tf/data_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from alphafold.model.tf import utils
import numpy as np
import tensorflow.compat.v1 as tf

# Pylint gets confused by the curry1 decorator because it changes the number
# of arguments to the function.
# pylint:disable=no-value-for-parameter
Expand Down Expand Up @@ -413,19 +412,20 @@ def make_masked_msa(protein, config, replace_fraction):
def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size,
num_res, num_templates=0):
"""Guess at the MSA and sequence dimensions to make fixed size."""

pad_size_map = {
NUM_RES: num_res,
NUM_MSA_SEQ: msa_cluster_size,
NUM_EXTRA_SEQ: extra_msa_size,
NUM_TEMPLATES: num_templates,
}

for k, v in protein.items():
# Don't transfer this to the accelerator.
if k == 'extra_cluster_assignment':
continue
shape = v.shape.as_list()
if type(v) ==np.ndarray:
shape = v.shape
else:
shape = v.shape.as_list()
schema = shape_schema[k]
assert len(shape) == len(schema), (
f'Rank mismatch between shape and shape schema for {k}: '
Expand Down