Commit 112b15e

add training and dataset prep scripts

jthorton committed Jun 28, 2024
1 parent f54a306 commit 112b15e
Showing 4 changed files with 371 additions and 0 deletions.
50 changes: 50 additions & 0 deletions scripts/dataset/analysis_train_val_test_split.py
@@ -0,0 +1,50 @@
# collect stats on the number of molecules, the range of formal charges, and the occurrences of each element in the train, validation, and test splits of the dataset
from rdkit import Chem
from rdkit.Chem import Descriptors
import deepchem as dc
import numpy as np

ps = Chem.SmilesParserParams()
ps.removeHs = False


def calculate_stats(dataset_name: str):
formal_charges = {}
molecular_weights = []
elements = {}
heavy_atom_count = []

# load the dataset
dataset = dc.data.DiskDataset(dataset_name)

for smiles in dataset.ids:
mol = Chem.MolFromSmiles(smiles, ps)
charges = []
for atom in mol.GetAtoms():
charges.append(atom.GetFormalCharge())
atomic_number = atom.GetAtomicNum()
            elements[atomic_number] = elements.get(atomic_number, 0) + 1

total_charge = sum(charges)
        formal_charges[total_charge] = formal_charges.get(total_charge, 0) + 1

molecular_weights.append(Descriptors.MolWt(mol))
heavy_atom_count.append(Descriptors.HeavyAtomCount(mol))

return formal_charges, molecular_weights, elements, heavy_atom_count


for dataset in ['maxmin-train', 'maxmin-valid', 'maxmin-test']:
charges, weights, atoms, heavy_atoms = calculate_stats(dataset_name=dataset)
    print(f'Running {dataset}: number of molecules {len(weights)}')
    print('Total formal charges', charges)
    print('Total elements', atoms)
    print(f'Average molecular weight {np.mean(weights):.2f} (std {np.std(weights):.2f})')
    print(f'Average number of heavy atoms {np.mean(heavy_atoms):.2f} (std {np.std(heavy_atoms):.2f})')
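
The element counts are keyed by atomic number; if symbols are easier to read, RDKit's periodic table can translate them (an optional sketch, reusing the atoms dict from the last loop iteration above):

pt = Chem.GetPeriodicTable()
# map atomic numbers to element symbols for a more readable report
print('Total elements', {pt.GetElementSymbol(z): n for z, n in atoms.items()})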

59 changes: 59 additions & 0 deletions scripts/dataset/setup_labeled_data.py
@@ -0,0 +1,59 @@
import h5py
import pyarrow
import pyarrow.parquet
from openff.units import unit
from collections import defaultdict
import deepchem as dc
import typing

# set up the parquet datasets using the splits generated by deepchem



# load up both files
training_db = h5py.File("TrainingSet-v1.hdf5", "r")
valid_test_db = h5py.File('ValSet-v1.hdf5', 'r')

def create_parquet_dataset(parquet_name: str, deep_chem_dataset: dc.data.DiskDataset, reference_datasets: typing.List[h5py.File]):
dataset_keys = deep_chem_dataset.X
dataset_smiles = deep_chem_dataset.ids
    column_names = ["smiles", "conformation", "dipole", "mbis-charges"]
results = defaultdict(list)
    # keep track of the total number of entries; each conformation is expanded into a unique training point
total_records = 0
for key, smiles in zip(dataset_keys, dataset_smiles):
for dataset in reference_datasets:
if key in dataset:
data_group = dataset[key]
group_smiles = data_group["smiles"].asstr()[0]
assert group_smiles == smiles
charges = data_group["mbis-charges"][()]
dipoles = data_group["dipole"][()]
conformations = data_group["conformations"][()] * unit.angstrom
                # work out how many entries we have
n_records = charges.shape[0]
total_records += n_records
for i in range(n_records):

results["smiles"].append(smiles)
results["mbis-charges"].append(charges[i])
results["dipole"].append(dipoles[i])
                    # make sure to store the flattened coordinates in bohr
results["conformation"].append(conformations[i].m_as(unit.bohr).flatten())


    for key, values in results.items():
        # fail loudly if any column is missing entries for some conformations
        assert len(values) == total_records, f"column '{key}' has {len(values)} entries, expected {total_records}"
    columns = [results[label] for label in column_names]

    table = pyarrow.table(columns, names=column_names)
pyarrow.parquet.write_table(table, parquet_name)


for file_name, dataset_name in [('training.parquet', 'maxmin-train'), ('validation.parquet', 'maxmin-valid'), ('testing.parquet', 'maxmin-test')]:
print('creating parquet for ', dataset_name)
dc_dataset = dc.data.DiskDataset(dataset_name)
create_parquet_dataset(parquet_name=file_name, deep_chem_dataset=dc_dataset, reference_datasets=[training_db, valid_test_db])

training_db.close()
valid_test_db.close()
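
A quick sanity check on the written files (a minimal sketch; assumes the loop above has produced training.parquet):

table = pyarrow.parquet.read_table("training.parquet")
print(table.num_rows, table.column_names)
# expect the four columns: smiles, conformation, dipole, mbis-charges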
32 changes: 32 additions & 0 deletions scripts/dataset/split_by_deepchem.py
@@ -0,0 +1,32 @@
# try splitting the entire collection of data using the deepchem splitters
import h5py
import deepchem as dc
import numpy as np

dataset_keys = []
smiles_ids = []
training_set = h5py.File('TrainingSet-v1.hdf5', 'r')
for key, group in training_set.items():
smiles_ids.append(group['smiles'].asstr()[0])
# use the key to quickly split the datasets later
dataset_keys.append(key)
training_set.close()

# val_set = h5py.File('ValSet-v1.hdf5', 'r')
# for key, group in val_set.items():
# smiles_ids.append(group['smiles'].asstr()[0])
# dataset_keys.append(key)

# val_set.close()


print(f'The total number of unique molecules: {len(smiles_ids)}')
print('Running MaxMin Splitter ...')

xs = np.array(dataset_keys)

total_dataset = dc.data.DiskDataset.from_numpy(X=xs, ids=smiles_ids)

max_min_split = dc.splits.MaxMinSplitter()
train, validation, test = max_min_split.train_valid_test_split(total_dataset, train_dir='maxmin-train', valid_dir='maxmin-valid', test_dir='maxmin-test')
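
DeepChem's train_valid_test_split defaults to an 80/10/10 split; the fractions can be made explicit with its frac_* keywords (the same call as above, shown with the documented defaults):

train, validation, test = max_min_split.train_valid_test_split(
    total_dataset,
    frac_train=0.8, frac_valid=0.1, frac_test=0.1,
    train_dir='maxmin-train', valid_dir='maxmin-valid', test_dir='maxmin-test',
)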

230 changes: 230 additions & 0 deletions scripts/training/train_model.py
@@ -0,0 +1,230 @@
# Test training script to make sure dipole prediction works
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
from pytorch_lightning.loggers import MLFlowLogger

from nagl.config import Config, DataConfig, ModelConfig, OptimizerConfig
from nagl.config.data import Dataset, DipoleTarget, ReadoutTarget
from nagl.config.model import GCNConvolutionModule, ReadoutModule, Sequential
from nagl.features import (
    AtomConnectivity,
    AtomFeature,
    AtomicElement,
    BondFeature,
    register_atom_feature,
    _CUSTOM_ATOM_FEATURES,
)
from nagl.training import DGLMoleculeDataModule, DGLMoleculeLightningModel
import typing
import logging
import pathlib
import pydantic
from rdkit import Chem
import dataclasses

DEFAULT_RING_SIZES = [3, 4, 5, 6, 7, 8]


# define our ring membership feature
@pydantic.dataclasses.dataclass(config={"extra": pydantic.Extra.forbid})
class AtomInRingOfSize(AtomFeature):
type: typing.Literal["ringofsize"] = "ringofsize"
ring_sizes: typing.List[pydantic.PositiveInt] = pydantic.Field(
DEFAULT_RING_SIZES,
        description="The ring sizes to check each atom's membership against",
)

def __len__(self):
return len(self.ring_sizes)

def __call__(self, molecule: Chem.Mol) -> torch.Tensor:
ring_info: Chem.RingInfo = molecule.GetRingInfo()

return torch.vstack(
[
torch.Tensor(
[
int(ring_info.IsAtomInRingOfSize(atom.GetIdx(), ring_size))
for ring_size in self.ring_sizes
]
)
for atom in molecule.GetAtoms()
]
)
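
# Illustrative check (a sketch using only the RDKit and torch imports above):
# every atom in benzene sits in a six-membered ring, so each row of the
# feature tensor has a 1 in the ring-size-6 column:
#   mol = Chem.MolFromSmiles("c1ccccc1")
#   feats = AtomInRingOfSize()(mol)  # shape (6, len(DEFAULT_RING_SIZES))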


def configure_model(
atom_features: typing.List[AtomFeature],
bond_features: typing.List[BondFeature],
n_gcn_layers: int,
n_gcn_hidden_features: int,
n_am1_layers: int,
n_am1_hidden_features: int,
) -> ModelConfig:
return ModelConfig(
atom_features=atom_features,
bond_features=bond_features,
convolution=GCNConvolutionModule(
type="SAGEConv",
hidden_feats=[n_gcn_hidden_features] * n_gcn_layers,
activation=["ReLU"] * n_gcn_layers,
),
readouts={
"mbis-charges": ReadoutModule(
pooling="atom",
forward=Sequential(
hidden_feats=[n_am1_hidden_features] * n_am1_layers + [2],
activation=["ReLU"] * n_am1_layers + ["Identity"],
),
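                # the final layer width of 2 matches NAGL's "charges"
                # postprocess, which converts the predicted per-atom
                # electronegativity and hardness into partial charges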
postprocess="charges",
)
},
)


def configure_data() -> DataConfig:
return DataConfig(
training=Dataset(
sources=["../datasets/training.parquet"],
            # The 'column' must match one of the label columns in the parquet
            # table that was created during stage 000.
            # The 'readout' should correspond to one of our model readout
            # keys.
            # denominators: charge in e, dipole in e*bohr (0.04 e*bohr ~ 0.1 D)
targets=[
ReadoutTarget(
column="mbis-charges",
readout="mbis-charges",
metric="rmse",
denominator=0.02,
),
DipoleTarget(
metric="rmse",
dipole_column="dipole",
conformation_column="conformation",
charge_label="mbis-charges",
denominator=0.04,
),
],
batch_size=250,
),
validation=Dataset(
sources=["../datasets/validation.parquet"],
targets=[
ReadoutTarget(
column="mbis-charges",
readout="mbis-charges",
metric="rmse",
denominator=0.02,
),
DipoleTarget(
metric="rmse",
dipole_column="dipole",
conformation_column="conformation",
charge_label="mbis-charges",
denominator=0.04,
),
],
),
test=Dataset(
sources=["../datasets/testing.parquet"],
targets=[
ReadoutTarget(
column="mbis-charges",
readout="mbis-charges",
metric="rmse",
denominator=0.02,
),
DipoleTarget(
metric="rmse",
dipole_column="dipole",
conformation_column="conformation",
charge_label="mbis-charges",
denominator=0.04,
),
],
),
)


def configure_optimizer(lr: float) -> OptimizerConfig:
return OptimizerConfig(type="Adam", lr=lr)


def main():
logging.basicConfig(level=logging.INFO)
output_dir = pathlib.Path("001-train-charge-model-small-mols")

register_atom_feature(AtomInRingOfSize)
print(_CUSTOM_ATOM_FEATURES)
# Configure our model, data sets, and optimizer.
model_config = configure_model(
atom_features=[
AtomicElement(values=["H", "C", "N", "O", "F", "P", "S", "Cl", "Br"]),
AtomConnectivity(),
dataclasses.asdict(AtomInRingOfSize()),
],
bond_features=[],
n_gcn_layers=5,
n_gcn_hidden_features=128,
n_am1_layers=2,
n_am1_hidden_features=64,
)
data_config = configure_data()

optimizer_config = configure_optimizer(0.001)

# Define the model and lightning data module that will contain the train, val,
# and test dataloaders if specified in ``data_config``.
config = Config(model=model_config, data=data_config, optimizer=optimizer_config)

model = DGLMoleculeLightningModel(config)
model.to_yaml("charge-dipole-v1.yaml")
print("Model", model)

    # The 'cache_dir' will store the fully featurized molecules so we don't
    # need to re-compute them each time we adjust a hyperparameter, for example.
data = DGLMoleculeDataModule(config, cache_dir=output_dir / "feature-cache")

    # Define an MLFlow experiment to store the outputs of training this model.
    # This will include the usual statistics as well as useful artifacts
    # highlighting the model's weak spots.
logger = MLFlowLogger(
experiment_name="mbis-charge-dipole-model-small-mols-1000",
save_dir=str(output_dir / "mlruns"),
log_model="all",
)

# The MLFlow UI can be opened by running:
#
# mlflow ui --backend-store-uri ./001-train-charge-model/mlruns \
# --default-artifact-root ./001-train-charge-model/mlruns
#

# Train the model
n_epochs = 1000

    n_gpus = 1 if torch.cuda.is_available() else 0
print(f"Using {n_gpus} GPUs")

    model_checkpoint = ModelCheckpoint(monitor='val/loss', dirpath=output_dir)
    trainer = pl.Trainer(
        # use the GPU when one was detected above, otherwise fall back to CPU
        accelerator='gpu' if n_gpus else 'cpu',
        devices=max(n_gpus, 1),
        min_epochs=n_epochs,
        max_epochs=n_epochs,
        logger=logger,
        log_every_n_steps=50,
        callbacks=[model_checkpoint],
    )

trainer.fit(model, datamodule=data)
trainer.test(model, datamodule=data)

print(model_checkpoint.best_model_path)


if __name__ == "__main__":
main()
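
After training, the best checkpoint printed above can be reloaded through the standard Lightning API (a sketch; load_from_checkpoint is inherited from pl.LightningModule, and depending on how DGLMoleculeLightningModel saves its hyperparameters the Config may need to be supplied explicitly):

best_model = DGLMoleculeLightningModel.load_from_checkpoint(
    model_checkpoint.best_model_path
)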
