-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add training and dataset prep scripts
- Loading branch information
Showing
4 changed files
with
371 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# collect stats on the number of molecules, the range of total formal charges and the occurrences of each element in the train, val and test splits of the dataset | ||
from rdkit import Chem | ||
from rdkit.Chem import Descriptors | ||
import deepchem as dc | ||
import numpy as np | ||
|
||
# Shared SMILES parser settings used by calculate_stats below: removeHs is switched off so hydrogens written in the SMILES stay on the parsed molecule | ||
ps = Chem.SmilesParserParams() | ||
ps.removeHs = False | ||
|
||
|
||
def calculate_stats(dataset_name: str):
    """Collect summary statistics for one DeepChem split stored on disk.

    The dataset ids are treated as SMILES strings and parsed with the
    module-level parser parameters ``ps``.

    Args:
        dataset_name: Directory of the ``dc.data.DiskDataset`` to load.

    Returns:
        A 4-tuple ``(formal_charges, molecular_weights, elements,
        heavy_atom_count)`` where

        * ``formal_charges`` maps a total molecular formal charge to the
          number of molecules carrying it,
        * ``molecular_weights`` lists ``Descriptors.MolWt`` per molecule,
        * ``elements`` maps an atomic number to the number of atoms of that
          element seen across the whole split,
        * ``heavy_atom_count`` lists ``Descriptors.HeavyAtomCount`` per
          molecule.

    Raises:
        ValueError: If RDKit cannot parse one of the dataset ids as SMILES.
    """
    formal_charges = {}
    molecular_weights = []
    elements = {}
    heavy_atom_count = []

    # load the dataset
    dataset = dc.data.DiskDataset(dataset_name)

    for smiles in dataset.ids:
        mol = Chem.MolFromSmiles(smiles, ps)
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None; fail
            # with the offending entry instead of an opaque AttributeError.
            raise ValueError(
                f"RDKit could not parse SMILES {smiles!r} from dataset {dataset_name!r}"
            )

        atoms = list(mol.GetAtoms())
        for atom in atoms:
            atomic_number = atom.GetAtomicNum()
            elements[atomic_number] = elements.get(atomic_number, 0) + 1

        # The molecular charge is the sum of the per-atom formal charges.
        total_charge = sum(atom.GetFormalCharge() for atom in atoms)
        formal_charges[total_charge] = formal_charges.get(total_charge, 0) + 1

        molecular_weights.append(Descriptors.MolWt(mol))
        heavy_atom_count.append(Descriptors.HeavyAtomCount(mol))

    return formal_charges, molecular_weights, elements, heavy_atom_count
|
||
|
||
# Print a short statistics report for each of the three MaxMin split
# directories, in train / valid / test order.
SPLIT_DIRECTORIES = ('maxmin-train', 'maxmin-valid', 'maxmin-test')

for dataset in SPLIT_DIRECTORIES:
    charges, weights, atoms, heavy_atoms = calculate_stats(dataset_name=dataset)

    # One molecular weight is recorded per molecule, so its length is the
    # molecule count of the split.
    n_molecules = len(weights)
    print(f'Running {dataset} number of molecules {n_molecules}')
    print('Total formal charges ', charges)
    print('Total elements', atoms)

    weight_mean, weight_std = np.mean(weights), np.std(weights)
    print(f'Average mol weight {weight_mean} and std {weight_std}')

    heavy_mean, heavy_std = np.mean(heavy_atoms), np.std(heavy_atoms)
    print(f'Average number of heavy atoms {heavy_mean} and std {heavy_std}')
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import h5py | ||
import pyarrow | ||
import pyarrow.parquet | ||
from openff.units import unit | ||
from collections import defaultdict | ||
import deepchem as dc | ||
import typing | ||
|
||
# setup the parquet datasets using the splits generated by deepchem | ||
|
||
|
||
|
||
# Open both source HDF5 files read-only when the script starts; the handles are passed to create_parquet_dataset as the reference datasets and are closed at the end of the script | ||
training_db = h5py.File("TrainingSet-v1.hdf5", "r") | ||
valid_test_db = h5py.File('ValSet-v1.hdf5', 'r') | ||
|
||
def create_parquet_dataset(parquet_name: str, deep_chem_dataset: dc.data.DiskDataset, reference_datasets: typing.List[h5py.File]):
    """Export one DeepChem split to a parquet file, one row per record.

    For every ``(key, smiles)`` pair in the split (``X`` holds the HDF5 group
    keys, ``ids`` the SMILES) the matching group is looked up in each
    reference HDF5 file. Each entry along the first axis of the group's
    ``mbis-charges`` array becomes one parquet row together with the
    same-index ``dipole`` and ``conformations`` entries.

    Conformations are tagged as angstrom, converted to bohr and flattened
    before being stored; charges and dipoles are stored as read.

    Keys that are absent from every reference file are skipped silently, and
    a key present in more than one reference file is exported once per file.

    Args:
        parquet_name: Path of the parquet file to write.
        deep_chem_dataset: Split whose ``X`` and ``ids`` select the molecules.
        reference_datasets: Open HDF5 files searched for each key.

    Raises:
        ValueError: If the SMILES stored in an HDF5 group differs from the
            SMILES recorded in the DeepChem split for the same key.
    """
    dataset_keys = deep_chem_dataset.X
    dataset_smiles = deep_chem_dataset.ids
    column_names = ["smiles", "conformation", "dipole", "mbis-charges"]
    results = defaultdict(list)
    # keep track of the number of total entries, this is each conformation
    # expanded as a unique training point
    total_records = 0
    for key, smiles in zip(dataset_keys, dataset_smiles):
        for dataset in reference_datasets:
            if key not in dataset:
                continue

            data_group = dataset[key]
            group_smiles = data_group["smiles"].asstr()[0]
            if group_smiles != smiles:
                # Explicit raise rather than assert so the check survives
                # running under ``python -O``.
                raise ValueError(
                    f"SMILES mismatch for key {key!r}: HDF5 group has "
                    f"{group_smiles!r} but the split has {smiles!r}"
                )

            charges = data_group["mbis-charges"][()]
            dipoles = data_group["dipole"][()]
            conformations = data_group["conformations"][()] * unit.angstrom
            # work out how many entries we have
            n_records = charges.shape[0]
            total_records += n_records
            for i in range(n_records):
                results["smiles"].append(smiles)
                results["mbis-charges"].append(charges[i])
                results["dipole"].append(dipoles[i])
                # make sure to store in bohr
                results["conformation"].append(conformations[i].m_as(unit.bohr).flatten())

    # Internal sanity check: every column must hold exactly one value per row.
    for key, values in results.items():
        assert len(values) == total_records, (
            f"column {key!r} has {len(values)} entries, expected {total_records}"
        )
    columns = [results[label] for label in column_names]

    table = pyarrow.table(columns, column_names)
    pyarrow.parquet.write_table(table, parquet_name)
|
||
|
||
# Convert each MaxMin split into its parquet file. The HDF5 handles opened at
# the top of the script are closed in ``finally`` so they are released even if
# one of the conversions raises.
try:
    for file_name, dataset_name in [('training.parquet', 'maxmin-train'), ('validation.parquet', 'maxmin-valid'), ('testing.parquet', 'maxmin-test')]:
        print('creating parquet for ', dataset_name)
        dc_dataset = dc.data.DiskDataset(dataset_name)
        create_parquet_dataset(parquet_name=file_name, deep_chem_dataset=dc_dataset, reference_datasets=[training_db, valid_test_db])
finally:
    training_db.close()
    valid_test_db.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# try splitting the entire collection of data using deepchem splitters | ||
import h5py | ||
import deepchem as dc | ||
import numpy as np | ||
|
||
# Collect, for every group in the training HDF5 file, its key and the first
# SMILES string stored in the group.
dataset_keys = []
smiles_ids = []
# The context manager closes the file even if reading a group fails.
with h5py.File('TrainingSet-v1.hdf5', 'r') as training_set:
    for key, group in training_set.items():
        smiles_ids.append(group['smiles'].asstr()[0])
        # use the key to quickly split the datasets later
        dataset_keys.append(key)

# The validation-set molecules are currently left out of the split; the
# disabled code below would add them.
# val_set = h5py.File('ValSet-v1.hdf5', 'r')
# for key, group in val_set.items():
#     smiles_ids.append(group['smiles'].asstr()[0])
#     dataset_keys.append(key)
#
# val_set.close()

print(f'The total number of unique molecules {len(smiles_ids)}')
print('Running MaxMin Splitter ...')

# The HDF5 keys are stored as the dataset "features" and the SMILES as the ids.
xs = np.array(dataset_keys)

total_dataset = dc.data.DiskDataset.from_numpy(X=xs, ids=smiles_ids)

# Split with the splitter's default fractions and write each split to its own
# directory.
max_min_split = dc.splits.MaxMinSplitter()
train, validation, test = max_min_split.train_valid_test_split(total_dataset, train_dir='maxmin-train', valid_dir='maxmin-valid', test_dir='maxmin-test')
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
# Test training script to make sure dipole prediction works | ||
import pytorch_lightning as pl | ||
from pytorch_lightning.callbacks import ModelCheckpoint | ||
import torch | ||
from pytorch_lightning.loggers import MLFlowLogger | ||
|
||
from nagl.config import Config, DataConfig, ModelConfig, OptimizerConfig | ||
from nagl.config.data import Dataset, DipoleTarget, ReadoutTarget | ||
from nagl.config.model import GCNConvolutionModule, ReadoutModule, Sequential | ||
from nagl.features import ( | ||
AtomConnectivity, | ||
AtomFeature, | ||
AtomicElement, | ||
BondFeature, | ||
AtomFeature, | ||
register_atom_feature, | ||
_CUSTOM_ATOM_FEATURES, | ||
) | ||
from nagl.training import DGLMoleculeDataModule, DGLMoleculeLightningModel | ||
import typing | ||
import logging | ||
import pathlib | ||
import pydantic | ||
from rdkit import Chem | ||
import dataclasses | ||
|
||
# Ring sizes checked by AtomInRingOfSize when none are given explicitly.
DEFAULT_RING_SIZES = [3, 4, 5, 6, 7, 8]


# Custom atom feature: for every atom, one 0/1 flag per configured ring size
# telling whether the atom is a member of a ring of that size.
@pydantic.dataclasses.dataclass(config={"extra": pydantic.Extra.forbid})
class AtomInRingOfSize(AtomFeature):
    type: typing.Literal["ringofsize"] = "ringofsize"
    ring_sizes: typing.List[pydantic.PositiveInt] = pydantic.Field(
        DEFAULT_RING_SIZES,
        description="The size of the ring we want to check membership of",
    )

    def __len__(self):
        """Return the feature width: one entry per configured ring size."""
        return len(self.ring_sizes)

    def __call__(self, molecule: Chem.Mol) -> torch.Tensor:
        """Build the ring-membership rows for every atom of ``molecule``.

        Returns a tensor made by stacking one row per atom, where each row
        has one 0/1 entry per value in ``ring_sizes``.
        """
        ring_info: Chem.RingInfo = molecule.GetRingInfo()

        per_atom_rows = []
        for atom in molecule.GetAtoms():
            atom_index = atom.GetIdx()
            membership = [
                int(ring_info.IsAtomInRingOfSize(atom_index, size))
                for size in self.ring_sizes
            ]
            per_atom_rows.append(torch.Tensor(membership))

        return torch.vstack(per_atom_rows)
|
||
|
||
def configure_model(
    atom_features: typing.List[AtomFeature],
    bond_features: typing.List[BondFeature],
    n_gcn_layers: int,
    n_gcn_hidden_features: int,
    n_am1_layers: int,
    n_am1_hidden_features: int,
) -> ModelConfig:
    """Assemble the model configuration: a SAGEConv stack plus one atom readout.

    Args:
        atom_features: Atom featurizers passed through to the model config.
        bond_features: Bond featurizers passed through to the model config.
        n_gcn_layers: Number of convolution layers.
        n_gcn_hidden_features: Width of every convolution layer.
        n_am1_layers: Number of hidden layers in the readout network.
        n_am1_hidden_features: Width of every hidden readout layer.

    Returns:
        A ``ModelConfig`` with a single ``"mbis-charges"`` readout.
    """
    # Every convolution layer shares the same width and a ReLU activation.
    convolution = GCNConvolutionModule(
        type="SAGEConv",
        hidden_feats=[n_gcn_hidden_features for _ in range(n_gcn_layers)],
        activation=["ReLU" for _ in range(n_gcn_layers)],
    )

    # The readout network ends in a 2-wide layer without non-linearity; its
    # output is handed to the "charges" postprocess step.
    readout_widths = [n_am1_hidden_features for _ in range(n_am1_layers)]
    readout_widths.append(2)
    readout_activations = ["ReLU" for _ in range(n_am1_layers)]
    readout_activations.append("Identity")

    charge_readout = ReadoutModule(
        pooling="atom",
        forward=Sequential(
            hidden_feats=readout_widths,
            activation=readout_activations,
        ),
        postprocess="charges",
    )

    return ModelConfig(
        atom_features=atom_features,
        bond_features=bond_features,
        convolution=convolution,
        readouts={"mbis-charges": charge_readout},
    )
|
||
|
||
def configure_data() -> DataConfig:
    """Describe the training, validation and test parquet data sets.

    All three splits are scored with the same two RMSE targets: the per-atom
    MBIS charges and the dipole derived from those charges. Only the training
    split sets an explicit batch size.
    """

    def build_targets():
        # A fresh list of fresh target objects is built on every call so no
        # split shares target instances with another.
        #
        # 'column' has to name a label column of the parquet table and
        # 'readout' has to name one of the model's readout keys.
        # The denominators are in e for the charges and in e*bohr for the
        # dipole (roughly 0.1 D).
        return [
            ReadoutTarget(
                column="mbis-charges",
                readout="mbis-charges",
                metric="rmse",
                denominator=0.02,
            ),
            DipoleTarget(
                metric="rmse",
                dipole_column="dipole",
                conformation_column="conformation",
                charge_label="mbis-charges",
                denominator=0.04,
            ),
        ]

    training_set = Dataset(
        sources=["../datasets/training.parquet"],
        targets=build_targets(),
        batch_size=250,
    )
    validation_set = Dataset(
        sources=["../datasets/validation.parquet"],
        targets=build_targets(),
    )
    test_set = Dataset(
        sources=["../datasets/testing.parquet"],
        targets=build_targets(),
    )

    return DataConfig(
        training=training_set,
        validation=validation_set,
        test=test_set,
    )
|
||
|
||
def configure_optimizer(lr: float) -> OptimizerConfig:
    """Return an Adam optimizer configuration with learning rate ``lr``."""
    optimizer_settings = {"type": "Adam", "lr": lr}
    return OptimizerConfig(**optimizer_settings)
|
||
|
||
def main():
    """Configure, train and test the MBIS charge/dipole model.

    Side effects: registers the custom ring feature, writes the model
    configuration to ``charge-dipole-v1.yaml``, and stores the feature cache,
    MLflow runs and checkpoints under ``001-train-charge-model-small-mols``.
    """
    logging.basicConfig(level=logging.INFO)
    run_dir = pathlib.Path("001-train-charge-model-small-mols")

    # The custom feature has to be registered before it is used in a config.
    register_atom_feature(AtomInRingOfSize)
    print(_CUSTOM_ATOM_FEATURES)

    # Configure our model, data sets, and optimizer.
    atom_features = [
        AtomicElement(values=["H", "C", "N", "O", "F", "P", "S", "Cl", "Br"]),
        AtomConnectivity(),
        # The custom feature is handed over as a plain dict rather than as an
        # instance.
        dataclasses.asdict(AtomInRingOfSize()),
    ]
    model_config = configure_model(
        atom_features=atom_features,
        bond_features=[],
        n_gcn_layers=5,
        n_gcn_hidden_features=128,
        n_am1_layers=2,
        n_am1_hidden_features=64,
    )
    data_config = configure_data()
    optimizer_config = configure_optimizer(0.001)

    # Bundle everything into the single config consumed by both the lightning
    # model and the data module (which provides the train, val and test
    # dataloaders declared in ``data_config``).
    config = Config(model=model_config, data=data_config, optimizer=optimizer_config)

    model = DGLMoleculeLightningModel(config)
    model.to_yaml("charge-dipole-v1.yaml")
    print("Model", model)

    # 'cache_dir' stores the fully featurized molecules so they do not have to
    # be recomputed each time a hyperparameter is adjusted.
    data = DGLMoleculeDataModule(config, cache_dir=run_dir / "feature-cache")

    # MLflow experiment that stores the outputs of training this model.
    logger = MLFlowLogger(
        experiment_name="mbis-charge-dipole-model-small-mols-1000",
        save_dir=str(run_dir / "mlruns"),
        log_model="all",
    )

    # The MLflow UI can be opened by pointing it at the run directory above:
    #
    #    mlflow ui --backend-store-uri ./001-train-charge-model-small-mols/mlruns \
    #              --default-artifact-root ./001-train-charge-model-small-mols/mlruns
    #

    # Train the model for a fixed number of epochs (min == max).
    n_epochs = 1000

    n_gpus = 1 if torch.cuda.is_available() else 0
    print(f"Using {n_gpus} GPUs")

    # NOTE(review): the trainer below is pinned to the CPU regardless of the
    # GPU count printed above - confirm this is intended.
    model_checkpoint = ModelCheckpoint(monitor='val/loss', dirpath=run_dir.joinpath(''))
    trainer = pl.Trainer(
        accelerator='cpu',
        # devices=n_gpus,
        min_epochs=n_epochs,
        max_epochs=n_epochs,
        logger=logger,
        log_every_n_steps=50,
        callbacks=[model_checkpoint],
    )

    trainer.fit(model, datamodule=data)
    trainer.test(model, datamodule=data)

    print(model_checkpoint.best_model_path)


if __name__ == "__main__":
    main()