diff --git a/alphafold/model/features.py b/alphafold/model/features.py
index c261cef19..062248996 100644
--- a/alphafold/model/features.py
+++ b/alphafold/model/features.py
@@ -77,10 +77,13 @@ def tf_example_to_features(tf_example: tf.train.Example,
 
 def np_example_to_features(np_example: FeatureDict,
                            config: ml_collections.ConfigDict,
-                           random_seed: int = 0) -> FeatureDict:
+                           random_seed: int = 0, desired_num_res: int = None) -> FeatureDict:
   """Preprocesses NumPy feature dict using TF pipeline."""
   np_example = dict(np_example)
-  num_res = int(np_example['seq_length'][0])
+  if desired_num_res is not None:
+    num_res = desired_num_res
+  else:
+    num_res = int(np_example['seq_length'][0])
   cfg, feature_names = make_data_config(config, num_res=num_res)
 
   if 'deletion_matrix_int' in np_example:
diff --git a/alphafold/model/model.py b/alphafold/model/model.py
index 072355acd..a48d36d6c 100644
--- a/alphafold/model/model.py
+++ b/alphafold/model/model.py
@@ -165,13 +165,15 @@ def predict(self,
     logging.info('Running predict with shape(feat) = %s',
                  tree.map_structure(lambda x: x.shape, feat))
     result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
-
     # This block is to ensure benchmark timings are accurate. Some blocking is
     # already happening when computing get_confidence_metrics, and this ensures
     # all outputs are blocked on.
     jax.tree_map(lambda x: x.block_until_ready(), result)
-    result.update(
-        get_confidence_metrics(result, multimer_mode=self.multimer_mode))
+    # result.update(
+    #     get_confidence_metrics(result, multimer_mode=self.multimer_mode))
+    result.update({'plddt': confidence.compute_plddt(
+        result['predicted_lddt']['logits'])})
+
     logging.info('Output shape was %s',
                  tree.map_structure(lambda x: x.shape, result))
     return result
diff --git a/alphafold/model/tf/data_transforms.py b/alphafold/model/tf/data_transforms.py
index 7af966ef4..05945dfca 100644
--- a/alphafold/model/tf/data_transforms.py
+++ b/alphafold/model/tf/data_transforms.py
@@ -20,7 +20,6 @@
 from alphafold.model.tf import utils
 import numpy as np
 import tensorflow.compat.v1 as tf
-
 # Pylint gets confused by the curry1 decorator because it changes the number
 # of arguments to the function.
 # pylint:disable=no-value-for-parameter
@@ -413,19 +412,20 @@ def make_masked_msa(protein, config, replace_fraction):
 def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size,
                     num_res, num_templates=0):
   """Guess at the MSA and sequence dimensions to make fixed size."""
-
   pad_size_map = {
       NUM_RES: num_res,
       NUM_MSA_SEQ: msa_cluster_size,
       NUM_EXTRA_SEQ: extra_msa_size,
       NUM_TEMPLATES: num_templates,
   }
-
   for k, v in protein.items():
     # Don't transfer this to the accelerator.
     if k == 'extra_cluster_assignment':
       continue
-    shape = v.shape.as_list()
+    if isinstance(v, np.ndarray):
+      shape = v.shape
+    else:
+      shape = v.shape.as_list()
     schema = shape_schema[k]
     assert len(shape) == len(schema), (
         f'Rank mismatch between shape and shape schema for {k}: '
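
For context, a minimal usage sketch (not part of the patch): it assumes the intent of the new desired_num_res argument is to pad every input to one fixed length so the jitted predict call is compiled only once, and it reads the per-residue pLDDT that the patched predict now places in the result dict. The model name, the file paths, and FIXED_NUM_RES are illustrative placeholders, not values taken from this change.

# Sketch only, under the assumptions stated above.
import pickle

from alphafold.model import config
from alphafold.model import data
from alphafold.model import features
from alphafold.model import model

FIXED_NUM_RES = 512  # assumed padding length; must be >= the real seq_length

model_config = config.model_config('model_1')
model_params = data.get_model_haiku_params(
    model_name='model_1', data_dir='/path/to/params')  # placeholder path
model_runner = model.RunModel(model_config, model_params)

# Raw feature dict from the data pipeline (e.g. the features.pkl written by
# run_alphafold.py).
with open('/path/to/features.pkl', 'rb') as f:  # placeholder path
  np_example = pickle.load(f)

# desired_num_res is the argument added by this patch; it overrides the true
# sequence length so every example is padded to the same fixed size.
feat = features.np_example_to_features(
    np_example, model_config, random_seed=0, desired_num_res=FIXED_NUM_RES)

result = model_runner.predict(feat, random_seed=0)
plddt = result['plddt']  # per-residue confidence over the padded length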