Constraints
Constraints in tmol allow you to define custom score functions in Python, alongside the sets of atoms that each score function will be evaluated on.
Using constraints requires:
- A constraint function. Several common functions are built in: harmonic, bounded, and circularharmonic. You may also define your own functions.
- An integer Tensor with the indices of each atom in the constraint. Its shape is [n_constraints, n_atoms, 3], where the 3 values in the last dimension are the [pose_index, residue_index, atom_index] of the atom in question.
- A float Tensor containing the parameters for the function for each constraint. The shape of this will be [n_constraints, n_parameters_per_constraint].
The constraints are attached to the pose_stack by getting its constraint set via get_constraint_set and then calling add_constraints with these 3 values.
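As a rough sketch of the expected shapes, the snippet below builds the index and parameter tensors for two atom-pair constraints by hand and checks them against the rules above. The atom indices are made-up placeholders rather than real tmol atom indices, and `check_constraint_inputs` is a hypothetical helper, not part of tmol.

```python
import torch


def check_constraint_inputs(cnstr_atoms: torch.Tensor, cnstr_params: torch.Tensor) -> None:
    """Hypothetical helper: raise ValueError if the tensors break the shape/dtype rules."""
    # atom indices: integer tensor of shape [n_constraints, n_atoms, 3]
    if cnstr_atoms.dim() != 3 or cnstr_atoms.shape[2] != 3:
        raise ValueError("atom indices must have shape [n_constraints, n_atoms, 3]")
    if cnstr_atoms.dtype.is_floating_point:
        raise ValueError("atom indices must be an integer tensor")
    # each constraint may involve between 1 and 4 atoms
    if not 1 <= cnstr_atoms.shape[1] <= 4:
        raise ValueError("each constraint must use between 1 and 4 atoms")
    # parameters: float tensor of shape [n_constraints, n_parameters_per_constraint]
    if cnstr_params.dim() != 2 or not cnstr_params.dtype.is_floating_point:
        raise ValueError("parameters must be a 2D float tensor")
    if cnstr_params.shape[0] != cnstr_atoms.shape[0]:
        raise ValueError("atom and parameter tensors disagree on n_constraints")


# two atom-pair constraints in pose 0; atom indices 2 and 0 are placeholders
cnstr_atoms = torch.tensor(
    [
        [[0, 3, 2], [0, 4, 0]],  # constraint 0: residue 3, atom 2 -> residue 4, atom 0
        [[0, 5, 2], [0, 6, 0]],  # constraint 1: residue 5, atom 2 -> residue 6, atom 0
    ],
    dtype=torch.int32,
)
# one target-distance parameter per constraint
cnstr_params = torch.tensor([[1.47], [1.47]], dtype=torch.float32)

check_constraint_inputs(cnstr_atoms, cnstr_params)
print(cnstr_atoms.shape, cnstr_params.shape)
```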
Example:
This code sets up a single harmonic distance constraint between the C atom of residue 3 and the N atom of residue 4, both in pose 0.
```python
import torch

# Assumes `pose_stack` (a tmol PoseStack), `torch_device`, and
# `ConstraintEnergyTerm` are already defined or imported.

# get the constraint set for the pose_stack (this will create one if it does not exist)
constraints = pose_stack.get_constraint_set()

# create a tensor to hold the atom indices of each constraint:
# we are creating [1] constraint that is a function of [2] atoms
cnstr_atoms = torch.full((1, 2, 3), 0, dtype=torch.int32, device=torch_device)
# ... and one to hold the parameters for each constraint:
# in this case, just [1] constraint with [1] parameter (the target distance)
cnstr_params = torch.full((1, 1), 0, dtype=torch.float32, device=torch_device)

# now fill those tensors for each constraint.
# grab the block type of each residue that we want atoms from
res1_type = pose_stack.block_type(0, 3)
res2_type = pose_stack.block_type(0, 4)
# write the [pose, residue, atom] indices of each atom into
# constraint 0, atoms 0 and 1
cnstr_atoms[0, 0] = torch.tensor([0, 3, res1_type.atom_to_idx["C"]])
cnstr_atoms[0, 1] = torch.tensor([0, 4, res2_type.atom_to_idx["N"]])
# now fill in the distance parameter for this constraint
cnstr_params[0, 0] = 1.47

# add the constraints to the constraint set, using the built-in harmonic function
constraints.add_constraints(
    ConstraintEnergyTerm.harmonic, cnstr_atoms, cnstr_params
)
```
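The same pattern extends to many constraints at once. The sketch below is a hypothetical helper, not something tmol provides: it fills one C(i) to N(i+1) distance constraint per listed residue in a single pose, using only the `block_type` and `atom_to_idx` lookups from the example above and assuming they behave as shown there. The returned tensors could then be passed to `add_constraints` in one call.

```python
import torch


def build_peptide_bond_constraints(pose_stack, pose_idx, residues, distance, device="cpu"):
    """Hypothetical helper: one C(i)-N(i+1) constraint per residue i in `residues`.

    Assumes pose_stack.block_type(pose, res).atom_to_idx maps atom names to
    atom indices, as in the example above.
    """
    n = len(residues)
    # [n_constraints, 2 atoms, (pose, residue, atom)]
    cnstr_atoms = torch.zeros((n, 2, 3), dtype=torch.int32, device=device)
    # [n_constraints, 1 parameter]: the same target distance for every constraint
    cnstr_params = torch.full((n, 1), distance, dtype=torch.float32, device=device)
    for i, res in enumerate(residues):
        c_idx = pose_stack.block_type(pose_idx, res).atom_to_idx["C"]
        n_idx = pose_stack.block_type(pose_idx, res + 1).atom_to_idx["N"]
        cnstr_atoms[i, 0] = torch.tensor([pose_idx, res, c_idx])
        cnstr_atoms[i, 1] = torch.tensor([pose_idx, res + 1, n_idx])
    return cnstr_atoms, cnstr_params
```

Building the full index tensor first and adding it in a single `add_constraints` call keeps all constraints of one function type together.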
Warning
The constraint set is specific to the pose_stack it was created for. A constraint set may be copied to another pose_stack with identical chemical makeup, but after the copy you should call get_constraint_set on the new pose_stack and use that reference.
In addition to the built-in functions, you may define your own. A custom function must follow these rules:
- It must take 2 arguments.
- The first argument is a Tensor of atom coordinates for each set of atoms, with shape [n_constraints x n_atoms x 3]. For example, if your function is applied to 5 different sets of atoms, with 4 atoms per set, this is a 5x4x3 float tensor.
- The second argument holds the parameters for your function. Your function does not have to use it, but the argument must still appear in the signature. If used, the parameters are a float Tensor of shape [n_constraints x n_parameters_per_constraint]. For example, to give each atom pair constraint its own target distance, 10 atom pair constraints with one distance parameter each would use a 10x1 tensor.
- The return value must be a float Tensor of shape [n_constraints], holding the evaluated score of each constraint.
- Each constraint must involve at least 1 and at most 4 atoms.
- During score attribution, the constraint score term attributes half of each constraint's score to the residue of the first atom in the atom set and half to the residue of the last atom. If both atoms belong to the same residue, that residue receives the full score.
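To illustrate the attribution rule, here is a plain-PyTorch sketch (not tmol's implementation) that splits each constraint's score evenly between the residues of its first and last atoms using `index_add_`. When both atoms sit on the same residue, the two halves land in the same slot, so that residue receives the full score. The helper name and the single-pose residue numbering are assumptions for illustration.

```python
import torch


def attribute_constraint_scores(scores, cnstr_atoms, n_residues):
    """Hypothetical sketch: per-residue score totals for a single pose.

    scores:      [n_constraints] float tensor of evaluated constraint scores
    cnstr_atoms: [n_constraints, n_atoms, 3] tensor of (pose, residue, atom)
    """
    first_res = cnstr_atoms[:, 0, 1].long()   # residue of the first atom
    last_res = cnstr_atoms[:, -1, 1].long()   # residue of the last atom
    per_res = torch.zeros(n_residues, dtype=scores.dtype)
    # half of each score goes to the first atom's residue, half to the last's
    per_res.index_add_(0, first_res, 0.5 * scores)
    per_res.index_add_(0, last_res, 0.5 * scores)
    return per_res


scores = torch.tensor([2.0, 3.0])
cnstr_atoms = torch.tensor([
    [[0, 3, 2], [0, 4, 0]],  # spans residues 3 and 4
    [[0, 5, 1], [0, 5, 3]],  # both atoms on residue 5
])
print(attribute_constraint_scores(scores, cnstr_atoms, n_residues=6))
```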
As an example, here is the code for an atom-pair distance constraint:
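```python
import torch


def harmonic(atoms, params):
    # atoms: [n_constraints, 2, 3] coordinates of each atom pair
    atoms1 = atoms[:, 0]
    atoms2 = atoms[:, 1]
    # distance between the two atoms of each constraint
    dist = torch.linalg.norm(atoms1 - atoms2, dim=-1)
    # squared deviation from the target distance in params[:, 0]
    return (dist - params[:, 0]) ** 2
```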
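A three-atom function can follow the same rules. The sketch below applies a harmonic penalty to the angle at the middle atom, with the target angle as the single parameter. The name `harmonic_angle` and the choice of radians are assumptions for illustration; this is not a tmol built-in.

```python
import math

import torch


def harmonic_angle(atoms, params):
    # atoms: [n_constraints, 3, 3]; the angle is measured at the middle atom
    v1 = atoms[:, 0] - atoms[:, 1]
    v2 = atoms[:, 2] - atoms[:, 1]
    cos_theta = torch.sum(v1 * v2, dim=-1) / (
        torch.linalg.norm(v1, dim=-1) * torch.linalg.norm(v2, dim=-1)
    )
    # clamp keeps acos from returning NaN when rounding pushes cos slightly past +/-1
    theta = torch.acos(torch.clamp(cos_theta, -1.0, 1.0))
    # params[:, 0] holds the target angle in radians (an assumed convention)
    return (theta - params[:, 0]) ** 2


# two toy constraints: a 90-degree angle and a straight (180-degree) angle
atoms = torch.tensor([
    [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [-1.0, 0.0, 0.0]],
])
params = torch.tensor([[math.pi / 2], [math.pi / 2]])
print(harmonic_angle(atoms, params))
```

The first constraint sits exactly at its target and scores 0; the second is off by pi/2 radians and scores (pi/2)^2. Because the derivative of acos is unbounded at 0 and 180 degrees, an atan2-based angle may be preferable if gradients near those angles matter.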