
Architecture

Jeff Flatten edited this page Jun 3, 2024 · 40 revisions

Overview

The basic concepts of tmol's architecture are:

  • The PoseStack - A collection of Poses (structures) on which tmol operates.
  • The ScoreFunction - A function that evaluates a PoseStack with one or more ScoreTerms.
  • The Minimizer - A gradient-descent algorithm that modifies the degrees of freedom of a PoseStack to minimize the value of a ScoreFunction.
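The following self-contained toy sketch, in plain Python, shows the Minimizer idea in isolation: a made-up energy (a harmonic spring between two atoms, with its analytic gradient) is reduced by fixed-step gradient descent on Cartesian coordinates. The names `bond_energy_and_grad` and `minimize` are hypothetical and not part of tmol's API; tmol's real minimizer works on PoseStack tensors, and its step rule and choice of degrees of freedom are not shown here.

```python
import math

def bond_energy_and_grad(coords, ideal=1.5, k=1.0):
    """Toy energy: a harmonic spring between atoms 0 and 1; returns (E, dE/dcoords)."""
    d = [coords[1][i] - coords[0][i] for i in range(3)]
    r = math.sqrt(sum(x * x for x in d))
    energy = k * (r - ideal) ** 2
    # dE/dr = 2k(r - ideal); dr/dx1 = d / r and dr/dx0 = -d / r
    s = 2.0 * k * (r - ideal) / r
    return energy, [[-s * x for x in d], [s * x for x in d]]

def minimize(coords, energy_and_grad, step=0.1, n_steps=200):
    """Fixed-step gradient descent on Cartesian coordinates (a toy stand-in)."""
    coords = [list(c) for c in coords]
    for _ in range(n_steps):
        _, grad = energy_and_grad(coords)
        coords = [[x - step * g for x, g in zip(atom, atom_grad)]
                  for atom, atom_grad in zip(coords, grad)]
    return coords

final = minimize([[0.0, 0.0, 0.0], [2.5, 0.0, 0.0]], bond_energy_and_grad)
print(round(math.dist(final[0], final[1]), 6))  # 1.5
```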

PoseStack

A PoseStack represents a set (a batch) of molecular systems. PoseStacks may be loaded from PDB files or may come from other sources such as RosettaFold.
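To give a feel for what a PDB file contributes before tmol resolves residue types, the sketch below pulls the per-atom fields out of fixed-column ATOM/HETATM records. It is a minimal illustration of the PDB format, not tmol's PDB reader; the function name `parse_atom_records` is hypothetical.

```python
def parse_atom_records(pdb_text):
    """Extract per-atom fields from fixed-column PDB ATOM/HETATM records.

    Column ranges follow the PDB format (1-based, inclusive): atom name 13-16,
    residue name 18-20, chain ID 22, residue number 23-26, x 31-38, y 39-46, z 47-54.
    """
    atoms = []
    for line in pdb_text.splitlines():
        if not line.startswith(("ATOM  ", "HETATM")):
            continue  # skip REMARK, TER, END, etc.
        atoms.append({
            "atom_name": line[12:16].strip(),
            "res_name": line[17:20].strip(),
            "chain": line[21],
            "res_num": int(line[22:26]),
            "xyz": (float(line[30:38]), float(line[38:46]), float(line[46:54])),
        })
    return atoms

example = "\n".join([
    "REMARK   toy example",
    "ATOM      1  N   MET A   1      11.104   6.134  -6.504",
])
print(parse_atom_records(example))
```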

Blocks

tmol breaks poses down into sub-structures called "blocks" (a generalization of the concept of "residue" from Rosetta). Each block is an instance of one of several "block types". A block type comprises a set of atoms, the properties of those atoms, the chemical bonds between those atoms, and the inter-block chemical bonds that join blocks together. The molecular system records the block type of each block, the coordinates of each atom, and how each block is connected (or not connected) to other blocks.
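The relationship between block types and blocks can be sketched with two small data classes. `BlockType` and `Pose` here are hypothetical, simplified stand-ins (tmol packs this information into tensors); the sketch only shows what is stored per block type versus per block, and how inter-block connections are recorded.

```python
from dataclasses import dataclass, field

@dataclass(frozen=True)
class BlockType:
    """Hypothetical, simplified stand-in for a tmol block type."""
    name: str
    atom_names: tuple   # e.g. ("N", "CA", "C", "O")
    bonds: tuple        # intra-block bonds as pairs of atom indices
    connections: tuple  # atom index of each inter-block connection point

@dataclass
class Pose:
    """A molecular system: one block type per block, coordinates, and connectivity."""
    block_types: list   # BlockType of each block
    coords: list        # per block, one (x, y, z) per atom
    # (block, connection) -> (other block, other connection); absent if unconnected
    inter_block_bonds: dict = field(default_factory=dict)

    def connect(self, b1, c1, b2, c2):
        """Record a chemical bond between connection c1 of b1 and connection c2 of b2."""
        self.inter_block_bonds[(b1, c1)] = (b2, c2)
        self.inter_block_bonds[(b2, c2)] = (b1, c1)

# Toy backbone-only residue: connection 0 ("down") at N, connection 1 ("up") at C.
gly_like = BlockType("GLYish", ("N", "CA", "C", "O"), ((0, 1), (1, 2), (2, 3)), (0, 2))
pose = Pose(block_types=[gly_like, gly_like],
            coords=[[(0.0, 0.0, 0.0)] * 4, [(3.8, 0.0, 0.0)] * 4])  # placeholder coordinates
pose.connect(0, 1, 1, 0)  # C of block 0 bonded to N of block 1
```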

Each "Pose" in the stack can hold as many residues as desired, and these blocks can be bonded together into one or more chains.

Note

tmol is more efficient at computing scores for systems of approximately the same size than for systems of very different sizes.
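One plausible reason for this, assuming poses are stored in tensors padded to the largest pose in the stack (as the [n_poses x max_n_residues ...] shapes of the canonical form suggest), is that every pose occupies as many slots as the largest one. The small, hypothetical helper below computes the fraction of padded slots that carry no real data.

```python
def padding_fraction(pose_sizes):
    """Fraction of slots in an [n_poses x max_size] layout that are padding."""
    max_size = max(pose_sizes)
    total_slots = len(pose_sizes) * max_size
    used_slots = sum(pose_sizes)
    return 1.0 - used_slots / total_slots

print(padding_fraction([100, 105, 98]))   # similar sizes: little padding
print(padding_fraction([100, 1000, 50]))  # very different sizes: mostly padding
```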

A PoseStack holds an object that stores the residue-type data for its residues. Most users will not interact directly with this class, but it is good to be aware that it is there. This class, PackedBlockTypes, is built from a collection of RefinedResidueType objects, which are constructed either from tmol's default set of residue types (described in a .yaml file) or programmatically by enterprising developers. Most users who need to think about this class at all will be content with the PackedBlockTypes object returned by tmol.default_packed_block_types(). The PackedBlockTypes object caches energy-term-specific data that is needed during energy evaluation, and creating this data can be somewhat slow; it is therefore most efficient to share a single PackedBlockTypes object among multiple PoseStacks.
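The caching behavior described above can be pictured with the following sketch, in which a hypothetical `BlockTypeSet` class builds each energy term's per-block-type data lazily and remembers it. It is not tmol's PackedBlockTypes implementation; it only illustrates why reusing one object across PoseStacks avoids repeating the slow setup.

```python
import time

class BlockTypeSet:
    """Hypothetical stand-in for a packed block-type collection with a per-term cache."""

    def __init__(self, block_type_names):
        self.block_type_names = tuple(block_type_names)
        self._term_data = {}  # term name -> precomputed data
        self.n_builds = 0     # how many times the slow setup actually ran

    def term_data(self, term_name, builder):
        """Return cached data for term_name, building it only on first use."""
        if term_name not in self._term_data:
            self.n_builds += 1
            self._term_data[term_name] = builder(self.block_type_names)
        return self._term_data[term_name]

def slow_hbond_setup(names):
    time.sleep(0.01)  # stands in for expensive preprocessing
    return {name: len(name) for name in names}

shared = BlockTypeSet(["ALA", "CYS", "HIS_D"])
# Two "pose stacks" sharing the same block-type object reuse a single build.
for _pose_stack in range(2):
    shared.term_data("hbond", slow_hbond_setup)
print(shared.n_builds)  # 1
```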

The coordinates of a PoseStack can be modified after construction; however, all other data members must be left unaltered. If you want to modify the residue type information for an existing PoseStack, you should construct a new PoseStack object instead.

A Note on Hydrogen Atoms

tmol, and Rosetta before it, creates an all-atom representation of a molecular system. For this reason, tmol represents several residue types differently from most popular ML models (such as AlphaFold/RosettaFold/OpenFold/ESMFold). In particular, tmol models hydrogens explicitly and is thus aware of chemical differences that other representations gloss over. For example, a cysteine may or may not form a disulfide bond with another cysteine. Other ML models make no representational distinction between these two states; in tmol, the disulfide-bonded cysteine has an extra inter-residue connection between its SG sulfur and the other cysteine, while the non-disulfide-bonded cysteine has a sulfhydryl hydrogen in its place. tmol therefore represents these two states with two different residue types. Similarly, tmol distinguishes between the two tautomers of histidine: one tautomer protonates the NE2 nitrogen of the imidazole ring, and the other ("HIS_D") protonates the ND1 nitrogen.
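As an illustration of the two cases above, this hypothetical helper picks a residue-type name from what is known about a residue's chemical state. "HIS_D" comes from the text; "CYD" for the disulfide-bonded cysteine is borrowed from Rosetta's naming and is an assumption, not necessarily the name tmol uses.

```python
def resolve_residue_type(three_letter, *, in_disulfide=False, protonated_nitrogen="NE2"):
    """Map a three-letter code plus chemical state to a residue-type name (sketch).

    Assumed names: "CYD" for disulfide-bonded cysteine (Rosetta's convention),
    "HIS" for the NE2-protonated tautomer and "HIS_D" for the ND1-protonated one.
    """
    if three_letter == "CYS":
        # Bonded cysteine: SG carries an inter-residue connection instead of a hydrogen.
        return "CYD" if in_disulfide else "CYS"
    if three_letter == "HIS":
        if protonated_nitrogen == "ND1":
            return "HIS_D"
        if protonated_nitrogen == "NE2":
            return "HIS"
        raise ValueError(f"unknown histidine nitrogen: {protonated_nitrogen}")
    return three_letter

print(resolve_residue_type("HIS", protonated_nitrogen="ND1"))  # HIS_D
```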

tmol will build hydrogens for you if you do not have their coordinates. This calculation is deterministic. For aliphatic hydrogens and most polar hydrogens, placement follows directly from the heavy-atom coordinates; for some polar hydrogens, however, there are degrees of freedom that go beyond the heavy-atom coordinates. tmol places hydroxyl, phenolic, and sulfhydryl hydrogens at a dihedral angle of 180 degrees (almost certainly not the optimal location for these atoms), and it always chooses the NE2-protonated histidine tautomer. The hydrogen-placement step is differentiable, so if you include the tmol energy in your ML model's loss and provide only heavy-atom coordinates, the energetic contribution of the automatically built hydrogens will feed into the positional derivatives of the heavy atoms that define their geometries.
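Placing an atom from a bond length, bond angle, and dihedral is a standard internal-to-Cartesian step. The sketch below uses the common natural-extension-reference-frame (NeRF) construction to place atom D from atoms A, B, C, here putting a serine-like HG at a 180-degree dihedral as described above. The bond length and angle values are illustrative, not tmol's parameters, and this is not claimed to be tmol's hydrogen-building code.

```python
import math

def sub(u, v): return [u[i] - v[i] for i in range(3)]
def add(u, v): return [u[i] + v[i] for i in range(3)]
def scale(u, s): return [x * s for x in u]
def dot(u, v): return sum(u[i] * v[i] for i in range(3))
def cross(u, v):
    return [u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0]]
def normalize(u): return scale(u, 1.0 / math.sqrt(dot(u, u)))

def place_atom(a, b, c, bond_length, bond_angle_deg, dihedral_deg):
    """NeRF: place D with |CD| = bond_length, angle B-C-D, and dihedral A-B-C-D."""
    theta = math.radians(bond_angle_deg)
    phi = math.radians(dihedral_deg)
    bc = normalize(sub(c, b))
    n = normalize(cross(sub(b, a), bc))  # normal to the A-B-C plane
    m = cross(n, bc)
    # Displacement of D expressed in the local (bc, m, n) frame.
    d_local = (-bond_length * math.cos(theta),
               bond_length * math.sin(theta) * math.cos(phi),
               bond_length * math.sin(theta) * math.sin(phi))
    return add(c, add(scale(bc, d_local[0]), add(scale(m, d_local[1]), scale(n, d_local[2]))))

def dihedral_deg(p0, p1, p2, p3):
    """Signed dihedral angle p0-p1-p2-p3 in degrees."""
    b1 = normalize(sub(p2, p1))
    b0 = sub(p0, p1)
    b2 = sub(p3, p2)
    v = sub(b0, scale(b1, dot(b0, b1)))
    w = sub(b2, scale(b1, dot(b2, b1)))
    return math.degrees(math.atan2(dot(cross(b1, v), w), dot(v, w)))

# Serine-like side chain: CA, CB, OG known; place HG trans to CA.
ca, cb, og = [0.0, 0.0, 0.0], [1.53, 0.0, 0.0], [2.0, 1.35, 0.0]
hg = place_atom(ca, cb, og, bond_length=0.96, bond_angle_deg=109.5, dihedral_deg=180.0)
print(round(abs(dihedral_deg(ca, cb, og, hg)), 3))  # 180.0
```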

A Note on Cropping

tmol includes explicit representations of terminal atoms, using different residue types for amino acids (and eventually other polymer subunits) in the middle of a peptide chain than for those at its ends. tmol also defines chemical bonds between sequential residues in the same chain. In several modeling tasks, however, the first and last residues of a chain may be absent, or sequential residues may not actually be chemically bonded. For example, when modeling a protein/protein interface, it might be most computationally efficient to represent only a small number of residues on either side of the interface. The first residue might not be the N terminus, and residues i and i+1 might be separated by many cropped-out residues. In such cases, adding the formal positive charge on the first residue's amino group might create unrealistic electrostatic interactions, and declaring a chemical bond between i and i+1 might put large forces on these residues as the score function tries to correct the "bad" covalent geometry. tmol's PoseStack construction process lets you control which residues are treated as regular polymeric positions and which are "exceptions to the rule" through the variable res_not_connected, described below.
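A sketch of how one might derive res_not_connected-style flags for a cropped chain, assuming the original residue numbers of the kept residues are available. The helper name `not_connected_flags` is hypothetical. Following the description of res_not_connected, position 0 flags a missing bond to the previous residue and position 1 a missing bond to the next; a kept first or last residue that is not a true terminus is also flagged, so that it is not built as a terminal variant.

```python
def not_connected_flags(kept_res_nums, chain_first, chain_last):
    """Return [prev_not_connected, next_not_connected] for each kept residue.

    kept_res_nums: original residue numbers of the residues kept after cropping,
        in order, all from one chain.
    chain_first / chain_last: residue numbers of the true chain termini.
    """
    flags = []
    n = len(kept_res_nums)
    for i, num in enumerate(kept_res_nums):
        if i == 0:
            # Only a real N terminus should get N-terminal chemistry.
            prev_gap = num != chain_first
        else:
            prev_gap = num != kept_res_nums[i - 1] + 1
        if i == n - 1:
            # Only a real C terminus should get C-terminal chemistry.
            next_gap = num != chain_last
        else:
            next_gap = kept_res_nums[i + 1] != num + 1
        flags.append([prev_gap, next_gap])
    return flags

# Interface crop of a chain numbered 1..200: residues 40-42 and 97-98 kept.
print(not_connected_flags([40, 41, 42, 97, 98], chain_first=1, chain_last=200))
# [[True, False], [False, False], [False, True], [True, False], [False, True]]
```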

The Canonical Form and Class CanonicalOrdering

Because tmol represents structures at a finer chemical granularity than most other ML packages, it has to resolve the chemical structure of the molecules coming into it from other sources. Even reading in a PDB file requires this chemical-type resolution. The inputs to tmol's PoseStack construction function, pose_stack_from_canonical_form, are:

  1. a CanonicalOrdering object - This represents a mapping of residues to integers and, for each residue, a mapping of atom names to unique integers.
  2. a PackedBlockTypes object - This is an object containing information about the block (residue) types for a PoseStack. The 'Packed' refers to the fact that this data has been packed into tensors so that it can be processed efficiently on the GPU.
  3. a set of three or more tensors typically bundled together in a dictionary referred to as the "canonical form"

The canonical form dictionary must contain:

  1. "chain_id": a tensor of torch.int32 of size [n_poses x max_n_residues], specifying the chain identifier for each residue in each pose.
  2. "res_types": a tensor of torch.int32 of size [n_poses x max_n_residues], specifying the integer representation of each residue's three-letter code in line with the ordering specified by the CanonicalOrdering object, where masked-out residues are indicated by a sentinel value of -1, and
  3. "coords": a floating-point tensor (e.g. torch.float32) of size [n_poses x max_n_residues x max_n_atoms_per_residue x 3], where the index along the third dimension indicates which atom is being described, and where atoms that are not being given to tmol should have their coordinates set to NaN (e.g. numpy.nan).
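The sketch below assembles a canonical-form-shaped dictionary for a batch of two poses of different lengths, using nested Python lists in place of torch tensors so the padding conventions are easy to see: -1 for masked-out residues and NaN for atoms not given to tmol. The orderings `RES_ORDER` and `ATOM_ORDER`, the `MAX_ATOMS` value, and the chain_id padding value are all made up for illustration; the real integers come from the CanonicalOrdering object, and in practice the entries are torch tensors (torch.int32 for chain_id and res_types, a floating-point type for coords).

```python
import math

# Hypothetical orderings; the real ones come from tmol's CanonicalOrdering.
RES_ORDER = {"ALA": 0, "GLY": 1}
ATOM_ORDER = {"N": 0, "CA": 1, "C": 2, "O": 3, "CB": 4}
MAX_ATOMS = 5  # hypothetical max_n_atoms_per_residue
NAN = math.nan

def canonical_form(poses):
    """poses: list of poses; each pose is a list of (chain, res_name, {atom_name: xyz})."""
    n_poses = len(poses)
    max_res = max(len(p) for p in poses)
    chain_id = [[-1] * max_res for _ in range(n_poses)]   # padding value is an assumption
    res_types = [[-1] * max_res for _ in range(n_poses)]  # -1 marks masked-out residues
    coords = [[[[NAN] * 3 for _ in range(MAX_ATOMS)] for _ in range(max_res)]
              for _ in range(n_poses)]                    # NaN marks atoms not provided
    for p, pose in enumerate(poses):
        for r, (chain, res_name, atoms) in enumerate(pose):
            chain_id[p][r] = chain
            res_types[p][r] = RES_ORDER[res_name]
            for atom_name, xyz in atoms.items():
                coords[p][r][ATOM_ORDER[atom_name]] = list(xyz)
    return {"chain_id": chain_id, "res_types": res_types, "coords": coords}

cf = canonical_form([
    [(0, "GLY", {"N": (0.0, 0.0, 0.0), "CA": (1.46, 0.0, 0.0)})],
    [(0, "ALA", {"N": (0.0, 0.0, 0.0)}), (0, "GLY", {"CA": (3.8, 0.0, 0.0)})],
])
print(cf["res_types"])  # [[1, -1], [0, 1]]
```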

In addition, the canonical form may also contain:

  1. "disulfides": a tensor of torch.int64 of size [n_dslf x 3] that lists disulfides as tuples of (pose_index, res1_index, res2_index). In many modeling problems, the indices of disulfide-bonded residues are known up front and can be given to tmol to avoid the step of detecting disulfide bonds based on distance. There are two reasons to skip this step: 1) a model might not place two disulfide-bonded residues close enough together for tmol to declare them disulfide-bonded, in which case tmol will be of no help in pushing these residues closer together, and 2) this step takes place on the CPU.

  2. "find_additional_disulfides": a boolean that controls whether or not the disulfide-detection step should be performed

  3. "res_not_connected": a tensor of torch.bool of size [n_poses x max_n_residues x 2]. This tensor indicates that a given (polymeric) residue is not connected to its previous (position 0) or next (position 1) residue; for terminal residues, a value of True causes the residue not to be built with its down (position 0) or up (position 1) terminal-variant type. The purpose is to let the user include a subset of the residues in a protein, omitting a series of "gap" residues between i and i+1 without those two residues being treated as if they were chemically bonded. This keeps the Ramachandran term from scoring nonsense dihedral angles and the cart-bonded term from scoring nonsense bond lengths and angles.

  4. "return_chain_ind": a boolean that, when True, alters the return type of pose_stack_from_canonical_form so that it returns a tuple whose first element is the PoseStack and whose second element is a torch.int32 tensor of chain identifiers for the re-indexed residues of the PoseStack. Two things should be noted. 1. PoseStack does not keep track of a chain identifier; a chain is essentially an emergent property of the chemical bonds. However, because a PoseStack can represent disconnected segments of a single chain, the chain identifier cannot always be recovered from the set of chemical bonds. At the moment, if you wish to keep track of the chain identifier for a particular residue, you must store it separately from the PoseStack. 2. Tracking the chain identifier is made more challenging by the fact that PoseStack construction excises any residues with a residue type of -1 (i.e. gap residues), and all residues appearing after those gap residues are given new indices (that is, they appear earlier in the list of non-gap residues). For convenience, this option returns the chain_id after the gap residues have been removed.

  5. "return_atom_mapping": a boolean that, when True, alters the return type so that it is a tuple whose first element is the PoseStack and whose last two elements are tensors t1 and t2 describing the mapping from atoms in the canonical-form tensor to their PoseStack indices. This could be used to update the coordinates in a PoseStack without rebuilding it (as long as the chemical identity is unchanged) or to remap derivatives to or from PoseStack ordering. If requested, the atom mapping is returned as the last two values:

            ps, t1, t2 = pose_stack_from_canonical_form(
                ...,
                return_atom_mapping=True
            )
            # Copy PoseStack coordinates back into canonical-form ordering.
            can_ord_coords[
                t1[:, 0], t1[:, 1], t1[:, 2]
            ] = ps.coords[
                t2[:, 0], t2[:, 1]
            ]
    

    where t1 is an [nats x 3] tensor (nats being the number of mapped atoms) in which

    • position [i, 0] is the pose index
    • position [i, 1] is the canonical-form residue index, and
    • position [i, 2] is the canonical-ordering atom index

    and t2 is an [nats x 2] tensor in which

    • position [i, 0] is the pose index, and
    • position [i, 1] is the pose-ordered atom index
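The following plain-list sketch mimics how a t1/t2 pair of the shapes described above could be built and used, and how chain identifiers look after gap residues are excised. It assumes that residues with res_types == -1 are dropped and that each provided (non-NaN) atom gets the next pose-ordered atom index in its pose; these are simplifying assumptions for illustration, since tmol's own atom ordering is determined by its block types. The helpers `build_mappings` and `copy_to_canonical` are hypothetical.

```python
import math

def build_mappings(res_types, chain_id, coords):
    """Return (chain_ind_after_gap_removal, t1, t2) for nested-list canonical-form data.

    Assumptions (illustration only): residues with res_types == -1 are dropped,
    and every non-NaN atom gets the next pose-ordered atom index in its pose.
    """
    chain_ind, t1, t2 = [], [], []
    for p, pose_res in enumerate(res_types):
        pose_chains = []
        next_atom = 0
        for r, rt in enumerate(pose_res):
            if rt == -1:
                continue  # gap residue: excised, so later residues shift down
            pose_chains.append(chain_id[p][r])
            for a, xyz in enumerate(coords[p][r]):
                if math.isnan(xyz[0]):
                    continue  # atom not provided
                t1.append((p, r, a))       # pose, canonical residue, canonical atom
                t2.append((p, next_atom))  # pose, pose-ordered atom
                next_atom += 1
        chain_ind.append(pose_chains)
    return chain_ind, t1, t2

def copy_to_canonical(can_coords, pose_coords, t1, t2):
    """Plain-list analog of can_ord_coords[t1 columns] = ps.coords[t2 columns]."""
    for (p, r, a), (p2, i) in zip(t1, t2):
        can_coords[p][r][a] = list(pose_coords[p2][i])

nan = math.nan
res_types = [[0, -1, 1]]
chain_id = [[0, 0, 1]]
coords = [[[[0.0, 0.0, 0.0], [nan] * 3],
           [[nan] * 3, [nan] * 3],
           [[5.0, 0.0, 0.0], [6.0, 0.0, 0.0]]]]
chain_ind, t1, t2 = build_mappings(res_types, chain_id, coords)
print(chain_ind)  # [[0, 1]]
print(t1)         # [(0, 0, 0), (0, 2, 0), (0, 2, 1)]
print(t2)         # [(0, 0), (0, 1), (0, 2)]
```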

A Note on Atom Names and Residue Type Resolution

(Describe tmol's residue-type-resolution logic)

Note that tmol.pose_stack_from_rosettafold2 has to strip out the "H" atom from the N-terminal residue, because that atom ought properly to be named "1H".

ScoreFunction

The ScoreFunction calculates the energies of a PoseStack as a weighted sum of ScoreTerms. Many ScoreTerms are provided: some represent physical forces such as electrostatics and van der Waals interactions, while others represent statistical terms such as the probability of observing a given pair of backbone torsion angles in Ramachandran space. New ScoreTerms may also be added.
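The weighted-sum structure can be sketched in a few lines. The term names, weights, and the toy "pose" (a plain list of numbers) are hypothetical; tmol's real terms operate on PoseStacks, but the combination step has this shape.

```python
def weighted_score(pose, terms, weights):
    """Total = sum over terms of weight * term(pose); also returns per-term values."""
    per_term = {name: term(pose) for name, term in terms.items()}
    total = sum(weights[name] * value for name, value in per_term.items())
    return total, per_term

# Hypothetical terms acting on a toy "pose" that is just a list of numbers.
terms = {
    "physical_like": lambda pose: sum(x * x for x in pose),
    "statistical_like": lambda pose: float(len(pose)),
}
weights = {"physical_like": 0.5, "statistical_like": 2.0}
total, per_term = weighted_score([1.0, 2.0], terms, weights)
print(total)  # 0.5 * 5.0 + 2.0 * 2.0 = 6.5
```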

When a ScoreFunction evaluates a PoseStack, each ScoreTerm breaks down its work on a per-block or per-block-pair basis, depending on whether the ScoreTerm is a one-body or a two-body term. All of this per-block and per-block-pair work is dispatched simultaneously, letting the GPU handle the scheduling of the individual block-based calculations.
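The decomposition into independent work items can be enumerated explicitly. This sketch assumes a two-body term considers every unordered pair of distinct blocks within a pose (a real implementation may also handle intra-block interactions or skip distant pairs); the function name `enumerate_work` is hypothetical. Because no item depends on another, the items can be evaluated in any order, which is what allows a GPU to schedule them freely.

```python
from itertools import combinations

def enumerate_work(n_blocks_per_pose):
    """List the independent work items for one-body and two-body terms."""
    one_body = [(p, b) for p, n in enumerate(n_blocks_per_pose) for b in range(n)]
    two_body = [(p, i, j) for p, n in enumerate(n_blocks_per_pose)
                for i, j in combinations(range(n), 2)]
    return one_body, two_body

one_body, two_body = enumerate_work([3, 2])
print(len(one_body), len(two_body))  # 5 4
```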

Torch vs C++

ScoreTerms may be written entirely in pure torch (such as RefEnergyTerm) or in C++ for better performance (such as HBondEnergyTerm).

For C++ ScoreTerms, both a CPU and a CUDA version are compiled. This compilation is done just-in-time (JIT) by Ninja when a term is first used. Every ScoreTerm is set up with boilerplate code so that the implementation code is shared between the CPU and CUDA versions. This means that the implementation code may only use functions that are available in both C++17 and CUDA (notably, facilities such as std::cout are unavailable).

Warning

There is currently a bug in the CUDA compilation whereby the JIT compiler may fail to recognize updates to the code. If you notice a difference between the behavior of your CPU and CUDA implementations, you may need to delete the locally cached object files to force a recompile.

C++ functions are exported to Python with pybind11. The arguments of the exported functions must all be passed as Tensors or ints. On the C++ side, the arguments are unpacked, and each tensor is assigned its dimensionality and data type by the TCAST function, which infers them from the argument type of the function that the tensor is being passed to.

Warning

It is critical that the order, dimensionality, and data type of the tensor arguments match between the calling Python code and the C++ function that consumes them. Mistakes here can produce errors that are very hard to interpret.

Whole-Pose vs Block-Pair scoring

Two varieties of scoring are provided: whole-pose and block-pair scoring. In whole-pose scoring, a single summed energy is returned for each structure in the PoseStack. In block-pair scoring, energies are returned for every pair of blocks in each structure as an N x N tensor, where N is the number of blocks. In block-pair scoring, one-body terms still return an N x N tensor, with their values placed only on the diagonal.
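The relationship between the two varieties can be sketched for a single pose. The helper `block_pair_matrix` is hypothetical, and so is its convention of splitting each pair energy evenly between [i][j] and [j][i]; that convention is an assumption chosen so that the table is symmetric and summing all of its entries reproduces the whole-pose total. One-body energies sit on the diagonal, as described above.

```python
def block_pair_matrix(n_blocks, one_body, two_body):
    """Assemble an N x N block-pair energy table for one pose (sketch).

    one_body: {block: energy}; two_body: {(i, j): energy} with i < j.
    Assumption: each pair energy is split evenly between [i][j] and [j][i].
    """
    table = [[0.0] * n_blocks for _ in range(n_blocks)]
    for b, e in one_body.items():
        table[b][b] += e  # one-body terms live on the diagonal
    for (i, j), e in two_body.items():
        table[i][j] += 0.5 * e
        table[j][i] += 0.5 * e
    return table

table = block_pair_matrix(3, {0: 1.0, 1: 2.0, 2: 3.0}, {(0, 1): -4.0, (1, 2): 0.5})
whole_pose = sum(sum(row) for row in table)
print(whole_pose)  # 1 + 2 + 3 - 4 + 0.5 = 2.5
```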

Derivatives

Derivatives in tmol are calculated either automatically by torch for pure-torch score terms, or by a separate backward pass function for C++ score terms.
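When a term supplies its own backward function, a common safeguard (this page does not say whether tmol does this) is to compare the hand-written derivative against a finite-difference estimate; for torch code, torch.autograd.gradcheck serves this purpose. The toy energy and helper names below are hypothetical.

```python
import math

def energy(x):
    """Toy 1D energy standing in for a hand-coded term."""
    return math.cos(x) + 0.5 * x * x

def energy_backward(x):
    """Hand-written derivative, like the separate backward pass of a C++ term."""
    return -math.sin(x) + x

def finite_difference(f, x, h=1e-6):
    """Central-difference estimate of df/dx."""
    return (f(x + h) - f(x - h)) / (2.0 * h)

for x in (-1.0, 0.3, 2.0):
    assert abs(energy_backward(x) - finite_difference(energy, x)) < 1e-6
```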
