
Architecture

Jeff Flatten edited this page Jun 26, 2024 · 40 revisions

Overview

The central focus of tmol is the PoseStack - a batch of structures. At its heart, tmol is a library for creating, scoring, manipulating, and exporting PoseStacks.

Creating PoseStacks

Under the hood, all PoseStack creation goes through a common function. Higher-level creation functions, such as loading from a PDB file or importing from RosettaFold2 or OpenFold, work by first converting the source data into a common representation: the CanonicalForm.

CanonicalForm

Because tmol represents structures with a higher chemical granularity than most other ML packages, it has to resolve the chemical structure of the molecules coming into it from other sources. Even the process of reading in a PDB file requires this chemical type resolution.

The CanonicalForm is a structure batch format that lets us represent data in tmol while deferring the chemical resolution. This makes loading the source data into tmol easier, and lets the chemical type resolution step use the same machinery regardless of data source. CanonicalForms are also stable, and can be serialized to disk and loaded years later to make a PoseStack.
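As a rough illustration of what "deferring chemical resolution" means, the sketch below stores a batch as nothing but integers and floats: residue types are integer codes from some data source's ordering, and atoms are identified by slot position rather than by a resolved chemical type. `ToyCanonicalForm` and its fields are hypothetical and much simpler than tmol's actual CanonicalForm; JSON is used only to show why a plain-data format is easy to save and reload.

```python
import json
from dataclasses import asdict, dataclass
from typing import List, Optional

XYZ = List[float]


@dataclass
class ToyCanonicalForm:
    """Hypothetical, simplified stand-in for a CanonicalForm.

    Only plain integers and floats are stored; which chemical block type
    each residue becomes is decided later, during resolution.
    """

    chain_id: List[List[int]]  # [pose][residue] -> chain index
    res_types: List[List[int]]  # [pose][residue] -> source residue-type code
    coords: List[List[List[Optional[XYZ]]]]  # [pose][residue][atom slot]; None = missing

    def to_json(self) -> str:
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, text: str) -> "ToyCanonicalForm":
        return cls(**json.loads(text))


cf = ToyCanonicalForm(
    chain_id=[[0, 0]],
    res_types=[[5, 12]],
    coords=[[[[0.0, 0.0, 0.0], None], [[1.5, 0.0, 0.0], [2.0, 1.0, 0.0]]]],
)
restored = ToyCanonicalForm.from_json(cf.to_json())
assert restored == cf  # round-trips without any chemistry involved
```

Because nothing in the form depends on chemical interpretation, reloading it later only requires that the integer codes still mean the same thing.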

Converting to CanonicalForm

The process of converting input data into a CanonicalForm varies by the input source.

Note

TODO: do we want to say anything about this? Maybe just delete this section?

Chemical Type Resolution

Once a batch of structures has been converted into a CanonicalForm, it can be combined with two data-source-specific objects to resolve the chemical structure and create a PoseStack:

  • The PackedBlockTypes object, which contains chemical information for the set of chemical types used by the data source.
  • The CanonicalOrdering object, which describes the mapping of chemical types to integers and also a mapping for each type of the atom names to unique integers.

Variants of these two objects for each data source are stored in the database, and more can be added in order to support PoseStack creation from new data sources.
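A CanonicalOrdering is essentially a pair of lookup tables. The sketch below uses a hypothetical `ToyCanonicalOrdering` (not tmol's class) to show how those tables turn the integer codes of a CanonicalForm back into named residues and atoms. A real resolver also consults the PackedBlockTypes to choose the full chemical block type (termini, protonation states, and so on), which is skipped here.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple


@dataclass
class ToyCanonicalOrdering:
    """Hypothetical per-data-source ordering (not tmol's actual class)."""

    restype_names: Tuple[str, ...]  # integer code -> residue-type name
    atom_names: Dict[str, Tuple[str, ...]]  # residue-type name -> atom names by slot

    def restype_index(self, name: str) -> int:
        return self.restype_names.index(name)

    def atom_index(self, restype: str, atom: str) -> int:
        return self.atom_names[restype].index(atom)


def resolve_residue(
    ordering: ToyCanonicalOrdering,
    res_code: int,
    atom_coords: Sequence[Optional[List[float]]],
) -> Tuple[str, Dict[str, List[float]]]:
    """Decode one residue's integer-coded data into named atoms (toy version)."""
    name = ordering.restype_names[res_code]
    slots = ordering.atom_names[name]
    present = {slots[i]: xyz for i, xyz in enumerate(atom_coords) if xyz is not None}
    return name, present


ordering = ToyCanonicalOrdering(
    restype_names=("ALA", "GLY"),
    atom_names={"ALA": ("N", "CA", "C", "O", "CB"), "GLY": ("N", "CA", "C", "O")},
)
name, atoms = resolve_residue(ordering, 1, [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], None, None])
```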

Note

It is important to note that the mappings in the CanonicalOrdering object for a data source must not change over time, as these mappings are what preserve the ability for a saved CanonicalForm to be loaded years after it was originally created.
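One way to guard this property is a check, perhaps run in a test suite, that a revised ordering only ever appends entries. The helper below is hypothetical and treats an ordering as a plain name-to-integer dict; the same check would apply to each residue type's atom-name mapping.

```python
from typing import Dict


def is_backward_compatible(old: Dict[str, int], new: Dict[str, int]) -> bool:
    """Hypothetical helper: True if `new` keeps every name -> index in `old`.

    Adding names is fine; removing or renumbering one is not, because
    integer codes saved in older CanonicalForms would then decode to the
    wrong residue type or atom.
    """
    return all(new.get(name) == index for name, index in old.items())


old = {"ALA": 0, "GLY": 1}
assert is_backward_compatible(old, {"ALA": 0, "GLY": 1, "SER": 2})  # appended
assert not is_backward_compatible(old, {"GLY": 0, "ALA": 1})  # renumbered
assert not is_backward_compatible(old, {"ALA": 0})  # removed
```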

Scoring PoseStacks

tmol can evaluate the energy of a PoseStack using a ScoreFunction that is composed of one or more EnergyTerms. The ScoreFunction will return a weighted sum of these EnergyTerms, using a user-defined set of weights.
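Stripped of batching and GPU details, the weighted sum looks like the sketch below. The two energy functions are made-up toy terms (a harmonic bond penalty and a constant), not tmol EnergyTerms; they only illustrate how per-term energies are combined by the weights.

```python
import math
from typing import Callable, Dict, Sequence, Tuple

Coords = Sequence[Tuple[float, float, float]]  # [atom] -> (x, y, z)


def weighted_total(
    terms: Dict[str, Callable[[Coords], float]],
    weights: Dict[str, float],
    coords: Coords,
) -> float:
    """Sum of weights[name] * term(coords) over all terms."""
    return sum(weights[name] * term(coords) for name, term in terms.items())


def toy_bond(coords: Coords) -> float:
    # Toy term: harmonic penalty on consecutive-atom distances (ideal length 1.0).
    return sum((math.dist(a, b) - 1.0) ** 2 for a, b in zip(coords, coords[1:]))


def toy_constant(coords: Coords) -> float:
    return -1.0  # toy term that ignores coordinates entirely


coords = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0)]
total = weighted_total(
    {"bond": toy_bond, "const": toy_constant},
    {"bond": 2.0, "const": 0.5},
    coords,
)
# 2.0 * (1.5 - 1.0) ** 2 + 0.5 * (-1.0) == 0.0
```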

ScoreFunction evaluation is separated into several steps:

  1. Precomputation of various EnergyTerm data needed by scoring.
  2. Packing of that data into tensors for efficient use on the GPU.
  3. Rendering of the torch modules.
  4. Calling the rendered ScoreFunction/EnergyTerms on the PoseStack coords.

Precomputation

Before a PoseStack is scored, some precomputation must happen to ensure that the EnergyTerms have the data they need for every block type in the PoseStack.

The ScoreFunction will have each EnergyTerm iterate over the complete list of RefinedResidueTypes used by any pose in the PoseStack to let them do any preprocessing of block-type specific data that they may need.

This step mostly involves pulling values for the EnergyTerm in question from the database and caching them in the RefinedResidueType.
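A minimal sketch of this per-block-type setup, assuming each block type carries a cache dict that terms write into. `ToyBlockType`, `ToyRefTerm`, `setup_block_type`, and the dict standing in for the database are all illustrative; check the EnergyTerm base class for the real hook names.

```python
from dataclasses import dataclass, field
from typing import Dict, Iterable, Tuple


@dataclass
class ToyBlockType:
    """Hypothetical stand-in for a RefinedResidueType."""

    name: str
    atoms: Tuple[str, ...]
    term_cache: Dict[str, object] = field(default_factory=dict)  # term name -> data


class ToyRefTerm:
    """Hypothetical term whose per-type data is a single reference energy."""

    name = "ref"

    def __init__(self, database: Dict[str, float]):
        self.database = database  # residue-type name -> reference energy
        self.lookups = 0

    def setup_block_type(self, block_type: ToyBlockType) -> None:
        if self.name in block_type.term_cache:
            return  # already cached; block types are shared across poses
        self.lookups += 1
        block_type.term_cache[self.name] = self.database[block_type.name]


def precompute(terms: Iterable[ToyRefTerm], block_types: Iterable[ToyBlockType]) -> None:
    # Every term visits every block type used by any pose in the stack.
    block_types = list(block_types)
    for term in terms:
        for bt in block_types:
            term.setup_block_type(bt)


ala = ToyBlockType("ALA", ("N", "CA", "C", "O", "CB"))
gly = ToyBlockType("GLY", ("N", "CA", "C", "O"))
ref = ToyRefTerm({"ALA": 0.3, "GLY": -0.1})
precompute([ref], [ala, gly])
precompute([ref], [ala, gly])  # second pass hits the cache
```

Caching on the block type means the database is consulted once per type, no matter how many poses use it.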

Packing

After any precomputation is finished, the ScoreFunction will 'pack' this data into an efficient tensor representation that can be transferred to the GPU and used directly by the torch modules.

This step centers around filling the PackedBlockTypes object. The ScoreFunction has each EnergyTerm pack any information it needs into this object. The packed data is usually derived from the precomputed data from the first step: each EnergyTerm takes that data and serializes it into compact tensors, which are then stored in the PackedBlockTypes object.
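The core trick in packing is turning ragged, per-block-type data into rectangular arrays that a tensor can hold. A minimal sketch, assuming a padding sentinel of -1 (a hypothetical choice here): each row is one block type, and unused slots are padded. Converting the result with something like `torch.tensor(...)` would give a tensor that GPU code can index by block-type number.

```python
from typing import List, Sequence


def pack_ragged(per_type: Sequence[Sequence[int]], pad: int = -1) -> List[List[int]]:
    """Hypothetical helper: pad per-block-type lists into one rectangular table.

    Row i holds block type i's values; `pad` (assumed -1) fills unused slots.
    """
    width = max((len(row) for row in per_type), default=0)
    return [list(row) + [pad] * (width - len(row)) for row in per_type]


# e.g. the atom indices some term cares about, per block type:
# ALA uses five atoms, GLY only four.
table = pack_ragged([[0, 1, 2, 3, 4], [0, 1, 2, 3]])
# table == [[0, 1, 2, 3, 4], [0, 1, 2, 3, -1]]
n_atoms = [sum(1 for v in row if v != -1) for row in table]  # [5, 4]
```

Code that consumes the table has to skip the padding, which is why an out-of-range sentinel is used rather than a valid index.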

Rendering a ScoringModule
For EnergyTerms:

In order for torch to actually use our EnergyTerms, we have to create a torch Module. The EnergyTerms use the function render_whole_pose_scoring_module to instantiate a module that is configured for running with the precomputed and packed data.

The ScoringModule itself defines a forward function that does the actual computation on the atom coordinates. This computation can either be pure torch Python code (for an example in code, look at the 'RefEnergyTerm'), or can be written in C++/CUDA ('CartBondedEnergyTerm', 'HBondEnergyTerm', etc).
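For a feel of the pure-Python route, here is a hypothetical module in the spirit of a reference-energy term; it is not RefEnergyTerm's actual code and avoids a torch dependency. Per-block-type data is captured when the module is built, so `forward` only needs the per-call input. In tmol the module would subclass `torch.nn.Module`, whose `__call__` dispatches to `forward`; the toy class mimics that by aliasing `__call__`.

```python
from typing import List, Optional, Sequence


class ToyRefScoringModule:
    """Hypothetical pure-Python analogue of a rendered scoring module."""

    def __init__(self, block_type_ref: Sequence[float], pose_block_types: List[List[int]]):
        self.block_type_ref = list(block_type_ref)  # block-type index -> reference energy
        # [pose][block] -> block-type index; -1 marks an empty slot (assumed
        # padding convention for poses with fewer blocks).
        self.pose_block_types = pose_block_types

    def forward(self, coords: Optional[object]) -> List[float]:
        # A reference energy depends only on block identity, so this toy
        # version accepts coordinates but never reads them.
        return [
            sum(self.block_type_ref[bt] for bt in blocks if bt >= 0)
            for blocks in self.pose_block_types
        ]

    # torch.nn.Module routes __call__ to forward; mimic that here.
    __call__ = forward


module = ToyRefScoringModule([1.0, -0.5], [[0, 1, 1], [1, -1, -1]])
per_pose = module(coords=None)  # [0.0, -0.5]
```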

For the ScoreFunction:

On the ScoreFunction level, the render function works differently. Rendering a ScoreFunction does not produce an actual torch module like the EnergyTerm's render. Instead, it is in this function that the component EnergyTerm precomputation, packing, and rendering functions are called. The returned value will be a callable configured for the specified PoseStack and Weights.
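The sketch below mirrors that flow with hypothetical pieces: `render_score_function` runs each term's precompute, pack, and render steps once, then returns a plain closure that sums weighted per-pose energies when called on coordinates. `RefLikeTerm` and its method names are illustrative only. The result is an ordinary Python function rather than a `torch.nn.Module`.

```python
from typing import Callable, Dict, List, Sequence


class RefLikeTerm:
    """Hypothetical term: per-pose sum of a per-block-type constant."""

    name = "ref"

    def __init__(self, database: Dict[str, float]):
        self.database = database

    def precompute(self, used_types: Sequence[str]) -> None:
        self.cache = {t: self.database[t] for t in used_types}

    def pack(self, used_types: Sequence[str]) -> List[float]:
        return [self.cache[t] for t in used_types]  # indexed like used_types

    def render(self, packed: List[float], pose_type_idx: List[List[int]]) -> Callable:
        def module(coords):
            return [sum(packed[i] for i in pose) for pose in pose_type_idx]

        return module


def render_score_function(terms, weights: Dict[str, float], poses: List[List[str]]) -> Callable:
    """Hypothetical ScoreFunction render: set up every term, return a closure."""
    used = sorted({t for pose in poses for t in pose})
    index = {t: i for i, t in enumerate(used)}
    pose_type_idx = [[index[t] for t in pose] for pose in poses]
    modules = {}
    for term in terms:
        term.precompute(used)
        modules[term.name] = term.render(term.pack(used), pose_type_idx)

    def score(coords) -> List[float]:
        # An ordinary Python closure, not a torch.nn.Module.
        outputs = {name: module(coords) for name, module in modules.items()}
        return [
            sum(weights[name] * out[p] for name, out in outputs.items())
            for p in range(len(poses))
        ]

    return score


score = render_score_function(
    [RefLikeTerm({"ALA": 1.0, "GLY": -0.5})], {"ref": 2.0}, [["ALA", "GLY"], ["GLY"]]
)
energies = score(None)  # toy term ignores coordinates; gives [1.0, -1.0]
```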

Note

TODO: Since the ScoreFunction's rendered 'module' isn't a torch module, does this mean it cannot be composed with other torch operations directly? Should we say something about this? Should we change this?

Note

TODO: Some sort of description of how the ScoringModules set up parameters (_p())

Calling the rendered ScoreFunction

The rendered ScoreFunction can be called with the PoseStack's coords tensor to compute the actual weighted-sum evaluation of the EnergyTerms on the conformation of those coordinates.

Note

TODO: Does this mean that the return value from calling the rendered ScoreFunction can be composed like other torch operations or the EnergyTerm modules?

The rendered ScoreFunction can be re-used as long as only the coordinate positions change (such as during minimization) (TODO: is this strictly true? is there anything else that can change without having to re-render?). Changes to the chemical block types of individual blocks, pose lengths, weights, etc. will all require re-rendering the ScoreFunction.
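Taking that rule at face value (the open TODO notwithstanding), a caller could key the rendered callable on everything except the coordinates and re-render only when that key changes. `RenderCache` is a hypothetical helper, and `toy_render` stands in for the real render step.

```python
from typing import Callable, Dict, Hashable, List, Optional


class RenderCache:
    """Hypothetical helper: re-render only when non-coordinate inputs change."""

    def __init__(self, render: Callable[[List[List[str]], Dict[str, float]], Callable]):
        self.render = render
        self.key: Optional[Hashable] = None
        self.scorer: Optional[Callable] = None
        self.n_renders = 0

    def get(self, block_types: List[List[str]], weights: Dict[str, float]) -> Callable:
        # Block types (and therefore lengths) and weights form the key.
        # Coordinates are left out: the rendered callable takes them as input.
        key = (tuple(tuple(p) for p in block_types), tuple(sorted(weights.items())))
        if key != self.key:
            self.scorer = self.render(block_types, weights)
            self.key = key
            self.n_renders += 1
        return self.scorer


def toy_render(block_types, weights):
    # Stand-in for a real render step.
    return lambda coords: sum(weights.values()) * len(coords)


cache = RenderCache(toy_render)
f1 = cache.get([["ALA", "GLY"]], {"ref": 1.0})
f2 = cache.get([["ALA", "GLY"]], {"ref": 1.0})  # same inputs: reused
f3 = cache.get([["ALA", "GLY"]], {"ref": 2.0})  # weights changed: re-rendered
```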

Note

TODO: need to mention the difference between pair-wise and whole-pose scoring.

Manipulating PoseStacks

TODO: this whole section could use a lot of work

Minimization

The tmol minimizer uses the derivatives calculated by the EnergyTerms to perform a gradient-based optimization of the PoseStack's coordinates.

Currently, minimization happens in Cartesian space, though it should also support kinematic-based minimization in the near future.

TODO: (accuracy, incomplete) The minimizer uses the L-BFGS algorithm with Armijo line search, with scaling and parameters taken from Rosetta. Other functions can theoretically be used, including torch's built in minimization algorithms. A replication of the Rosetta minimization was included because XXX, YYY, and ZZZ.
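To make the line-search idea concrete, the sketch below pairs an Armijo backtracking line search with plain steepest descent. Note the swap: it uses the raw negative gradient as the search direction instead of an L-BFGS direction, so it illustrates the line search only. It is not tmol's minimizer, and its constants (`c1`, `shrink`, tolerances) are arbitrary choices rather than Rosetta's.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def armijo_descent(
    f: Callable[[Vector], float],
    grad: Callable[[Vector], Vector],
    x0: Sequence[float],
    c1: float = 1e-4,
    shrink: float = 0.5,
    gtol: float = 1e-8,
    max_iter: int = 1000,
) -> Vector:
    """Steepest descent with Armijo backtracking (not L-BFGS)."""
    x = list(x0)
    for _ in range(max_iter):
        g = grad(x)
        gg = sum(gi * gi for gi in g)
        if gg < gtol * gtol:
            break  # gradient small enough: converged
        fx = f(x)
        step = 1.0
        while True:
            trial = [xi - step * gi for xi, gi in zip(x, g)]
            # Sufficient-decrease (Armijo) condition along d = -g:
            # f(x + step*d) <= f(x) + c1*step*(g . d) = f(x) - c1*step*|g|^2
            if f(trial) <= fx - c1 * step * gg:
                break
            step *= shrink
            if step < 1e-12:
                return x  # no acceptable step; stop at the current point
        x = trial
    return x


# Toy quadratic bowl with its minimum at (1, -2).
def bowl(x: Vector) -> float:
    return (x[0] - 1.0) ** 2 + 10.0 * (x[1] + 2.0) ** 2


def bowl_grad(x: Vector) -> Vector:
    return [2.0 * (x[0] - 1.0), 20.0 * (x[1] + 2.0)]


x_min = armijo_descent(bowl, bowl_grad, [0.0, 0.0])
```

Swapping in a quasi-Newton direction would change only how the search direction is computed; the backtracking loop stays the same.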

TODO: Other manipulations? phi/psi setting?

Exporting PoseStacks

Exporting PoseStacks is just the inverse of PoseStack creation. Like creation, going back to the CanonicalForm uses common code, but requires a data-source-specific CanonicalOrdering object. From the CanonicalForm, data-source-specific code must be written to convert back into the source's data format.
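Since the CanonicalOrdering is a set of lookup tables, the export direction is mostly inverse lookups. The sketch below (a hypothetical function over a simplified CanonicalForm-like layout) turns integer codes back into named atom records; a real exporter would then hand these to source-specific code, such as a PDB writer.

```python
from typing import Dict, List, Optional, Sequence, Tuple

# (residue index, residue name, atom name, xyz)
Record = Tuple[int, str, str, Tuple[float, ...]]


def decode_pose(
    restype_names: Sequence[str],
    atom_names: Dict[str, Sequence[str]],
    res_codes: Sequence[int],
    coords: Sequence[Sequence[Optional[Sequence[float]]]],
) -> List[Record]:
    """Hypothetical helper: map residue codes and atom slots back to names."""
    records: List[Record] = []
    for res_idx, (code, atom_xyz) in enumerate(zip(res_codes, coords)):
        name = restype_names[code]
        for slot, xyz in enumerate(atom_xyz):
            if xyz is not None:  # missing atoms are simply not exported
                records.append((res_idx, name, atom_names[name][slot], tuple(xyz)))
    return records


records = decode_pose(
    restype_names=["ALA", "GLY"],
    atom_names={"ALA": ["N", "CA", "C", "O", "CB"], "GLY": ["N", "CA", "C", "O"]},
    res_codes=[1],
    coords=[[[0.0, 0.0, 0.0], None, [1.0, 1.0, 1.0], None]],
)
```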

Note

TODO: we should be able to convert to data formats other than the original source, assuming they both have all the required block-types that are present in the PoseStack, right? It seems trivially true for PDBs at least. Might be worth talking about this.

Python, C++, and CUDA

tmol is primarily written in Python, with C++/CUDA being used to write optimized low level code for specific operations (most EnergyTerms, for example). C++ functions are exported to Python by pybind.

When C++/CUDA is used, both a CPU and a CUDA version are compiled. This compilation is done Just-In-Time (JIT) by Ninja on first use. tmol uses a 'diamond' structure to share the implementation code between C++ and CUDA. Note that this means implementation code may only use functions that are available in both C++17 and CUDA device code (critically, facilities like std::cout are missing).

Warning

There is currently a bug in the CUDA compilation where the JIT compiling may fail to recognize updates to the code. If you notice a difference between the behavior of your C++ and CUDA implementations, you may need to delete the local cached object files to force a recompile.
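If the extensions are built through PyTorch's `torch.utils.cpp_extension` JIT machinery (an assumption here, not something this page states), their build directories live under the directory named by the `TORCH_EXTENSIONS_DIR` environment variable when it is set, and under a platform- and version-dependent default otherwise. The hypothetical helper below deletes matching build directories so the next import triggers a full recompile; pass the cache location explicitly if you are unsure where it is.

```python
import os
import shutil
from pathlib import Path
from typing import List, Optional


def clear_extension_cache(match: str, cache_dir: Optional[str] = None) -> List[str]:
    """Hypothetical helper: remove cached build dirs whose names contain `match`.

    ASSUMPTION: builds go through torch.utils.cpp_extension, which honors
    TORCH_EXTENSIONS_DIR. Without that variable, pass `cache_dir` explicitly.
    """
    root = cache_dir or os.environ.get("TORCH_EXTENSIONS_DIR")
    if not root:
        raise ValueError("pass cache_dir or set TORCH_EXTENSIONS_DIR")
    removed = []
    for entry in sorted(Path(root).iterdir()):
        if entry.is_dir() and match in entry.name:
            shutil.rmtree(entry)  # drop cached objects so the next load rebuilds
            removed.append(entry.name)
    return removed
```

The `match` filter keeps unrelated extensions in the same cache from being rebuilt unnecessarily; which names to match is up to you.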
