From 112b15ed373463e8f66b89eb1e893055a0a3d4ee Mon Sep 17 00:00:00 2001
From: Josh Horton
Date: Fri, 28 Jun 2024 10:36:45 +0100
Subject: [PATCH] add training and dataset prep scripts

---
 .../dataset/analysis_train_val_test_split.py |  50 ++++
 scripts/dataset/setup_labeled_data.py        |  59 +++++
 scripts/dataset/split_by_deepchem.py         |  32 +++
 scripts/training/train_model.py              | 230 ++++++++++++++++++
 4 files changed, 371 insertions(+)
 create mode 100644 scripts/dataset/analysis_train_val_test_split.py
 create mode 100644 scripts/dataset/setup_labeled_data.py
 create mode 100644 scripts/dataset/split_by_deepchem.py
 create mode 100644 scripts/training/train_model.py

diff --git a/scripts/dataset/analysis_train_val_test_split.py b/scripts/dataset/analysis_train_val_test_split.py
new file mode 100644
index 0000000..1e80c80
--- /dev/null
+++ b/scripts/dataset/analysis_train_val_test_split.py
@@ -0,0 +1,50 @@
+# collect stats on the number of molecules, the range of total formal charges
+# and the occurrences of each element in the train, val and test splits
+from rdkit import Chem
+from rdkit.Chem import Descriptors
+import deepchem as dc
+import numpy as np
+
+# parse the SMILES without removing the explicit hydrogens
+ps = Chem.SmilesParserParams()
+ps.removeHs = False
+
+
+def calculate_stats(dataset_name: str):
+    formal_charges = {}
+    molecular_weights = []
+    elements = {}
+    heavy_atom_count = []
+
+    # load the dataset
+    dataset = dc.data.DiskDataset(dataset_name)
+
+    for smiles in dataset.ids:
+        mol = Chem.MolFromSmiles(smiles, ps)
+        charges = []
+        for atom in mol.GetAtoms():
+            charges.append(atom.GetFormalCharge())
+            atomic_number = atom.GetAtomicNum()
+            if atomic_number in elements:
+                elements[atomic_number] += 1
+            else:
+                elements[atomic_number] = 1
+
+        total_charge = sum(charges)
+        if total_charge in formal_charges:
+            formal_charges[total_charge] += 1
+        else:
+            formal_charges[total_charge] = 1
+
+        molecular_weights.append(Descriptors.MolWt(mol))
+        heavy_atom_count.append(Descriptors.HeavyAtomCount(mol))
+
+    return formal_charges, molecular_weights, elements, heavy_atom_count
+
+
+for dataset in ['maxmin-train', 'maxmin-valid', 'maxmin-test']:
+    charges, weights, atoms, heavy_atoms = calculate_stats(dataset_name=dataset)
+    print(f'Running {dataset}: number of molecules {len(weights)}')
+    print('Total formal charges:', charges)
+    print('Total elements:', atoms)
+    print(f'Average mol weight {np.mean(weights)} and std {np.std(weights)}')
+    print(f'Average number of heavy atoms {np.mean(heavy_atoms)} and std {np.std(heavy_atoms)}')

diff --git a/scripts/dataset/setup_labeled_data.py b/scripts/dataset/setup_labeled_data.py
new file mode 100644
index 0000000..986689c
--- /dev/null
+++ b/scripts/dataset/setup_labeled_data.py
@@ -0,0 +1,59 @@
+import h5py
+import pyarrow
+import pyarrow.parquet
+from openff.units import unit
+from collections import defaultdict
+import deepchem as dc
+import typing
+
+# set up the parquet datasets using the splits generated by deepchem
+
+# load up both reference files
+training_db = h5py.File("TrainingSet-v1.hdf5", "r")
+valid_test_db = h5py.File('ValSet-v1.hdf5', 'r')
+
+
+def create_parquet_dataset(parquet_name: str, deep_chem_dataset: dc.data.DiskDataset, reference_datasets: typing.List[h5py.File]):
+    dataset_keys = deep_chem_dataset.X
+    dataset_smiles = deep_chem_dataset.ids
+    column_names = ["smiles", "conformation", "dipole", "mbis-charges"]
+    results = defaultdict(list)
+    # keep track of the total number of entries; each conformation is expanded
+    # into a unique training point
+    total_records = 0
+    for key, smiles in zip(dataset_keys, dataset_smiles):
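+        # the molecule key may live in either source HDF5 file, so check
+        # every reference dataset for its labels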
+        for dataset in reference_datasets:
+            if key in dataset:
+                data_group = dataset[key]
+                group_smiles = data_group["smiles"].asstr()[0]
+                assert group_smiles == smiles
+                charges = data_group["mbis-charges"][()]
+                dipoles = data_group["dipole"][()]
+                conformations = data_group["conformations"][()] * unit.angstrom
+                # work out how many entries we have
+                n_records = charges.shape[0]
+                total_records += n_records
+                for i in range(n_records):
+                    results["smiles"].append(smiles)
+                    results["mbis-charges"].append(charges[i])
+                    results["dipole"].append(dipoles[i])
+                    # make sure to store the flattened conformation in bohr
+                    results["conformation"].append(conformations[i].m_as(unit.bohr).flatten())
+
+    # every column should have one entry per expanded conformation record
+    for key, values in results.items():
+        assert len(values) == total_records, key
+
+    columns = [results[label] for label in column_names]
+    table = pyarrow.table(columns, column_names)
+    pyarrow.parquet.write_table(table, parquet_name)
+
+
+for file_name, dataset_name in [('training.parquet', 'maxmin-train'), ('validation.parquet', 'maxmin-valid'), ('testing.parquet', 'maxmin-test')]:
+    print('creating parquet for', dataset_name)
+    dc_dataset = dc.data.DiskDataset(dataset_name)
+    create_parquet_dataset(parquet_name=file_name, deep_chem_dataset=dc_dataset, reference_datasets=[training_db, valid_test_db])
+
+training_db.close()
+valid_test_db.close()
\ No newline at end of file

diff --git a/scripts/dataset/split_by_deepchem.py b/scripts/dataset/split_by_deepchem.py
new file mode 100644
index 0000000..122b60e
--- /dev/null
+++ b/scripts/dataset/split_by_deepchem.py
@@ -0,0 +1,32 @@
+# try splitting the entire collection of data using the deepchem splitters
+import h5py
+import deepchem as dc
+import numpy as np
+
+dataset_keys = []
+smiles_ids = []
+training_set = h5py.File('TrainingSet-v1.hdf5', 'r')
+for key, group in training_set.items():
+    smiles_ids.append(group['smiles'].asstr()[0])
+    # use the key to quickly split the datasets later
+    dataset_keys.append(key)
+training_set.close()
+
+# val_set = h5py.File('ValSet-v1.hdf5', 'r')
+# for key, group in val_set.items():
+#     smiles_ids.append(group['smiles'].asstr()[0])
+#     dataset_keys.append(key)
+
+# val_set.close()
+
+
+print(f'The total number of unique molecules: {len(smiles_ids)}')
+print('Running MaxMin Splitter ...')
+
+xs = np.array(dataset_keys)
+
+total_dataset = dc.data.DiskDataset.from_numpy(X=xs, ids=smiles_ids)
+
+# split on fingerprint diversity; train_valid_test_split defaults to
+# 0.8/0.1/0.1 train/valid/test fractions
+max_min_split = dc.splits.MaxMinSplitter()
+train, validation, test = max_min_split.train_valid_test_split(total_dataset, train_dir='maxmin-train', valid_dir='maxmin-valid', test_dir='maxmin-test')

diff --git a/scripts/training/train_model.py b/scripts/training/train_model.py
new file mode 100644
index 0000000..439aae8
--- /dev/null
+++ b/scripts/training/train_model.py
@@ -0,0 +1,230 @@
+# Test training script to make sure dipole prediction works
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import ModelCheckpoint
+import torch
+from pytorch_lightning.loggers import MLFlowLogger
+
+from nagl.config import Config, DataConfig, ModelConfig, OptimizerConfig
+from nagl.config.data import Dataset, DipoleTarget, ReadoutTarget
+from nagl.config.model import GCNConvolutionModule, ReadoutModule, Sequential
+from nagl.features import (
+    AtomConnectivity,
+    AtomFeature,
+    AtomicElement,
+    BondFeature,
+    register_atom_feature,
+    _CUSTOM_ATOM_FEATURES,
+)
+from nagl.training import DGLMoleculeDataModule, DGLMoleculeLightningModel
+import typing
+import logging
+import pathlib
+import pydantic
+from rdkit import Chem
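+# dataclasses is used in main() to convert the custom ring feature into a
+# plain dict before it is passed to the model config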
+import dataclasses
+
+DEFAULT_RING_SIZES = [3, 4, 5, 6, 7, 8]
+
+
+# define our ring membership feature
+@pydantic.dataclasses.dataclass(config={"extra": pydantic.Extra.forbid})
+class AtomInRingOfSize(AtomFeature):
+    type: typing.Literal["ringofsize"] = "ringofsize"
+    ring_sizes: typing.List[pydantic.PositiveInt] = pydantic.Field(
+        DEFAULT_RING_SIZES,
+        description="The sizes of the rings we want to check membership of",
+    )
+
+    def __len__(self):
+        return len(self.ring_sizes)
+
+    def __call__(self, molecule: Chem.Mol) -> torch.Tensor:
+        ring_info: Chem.RingInfo = molecule.GetRingInfo()
+
+        # one row per atom, one column per ring size
+        return torch.vstack(
+            [
+                torch.Tensor(
+                    [
+                        int(ring_info.IsAtomInRingOfSize(atom.GetIdx(), ring_size))
+                        for ring_size in self.ring_sizes
+                    ]
+                )
+                for atom in molecule.GetAtoms()
+            ]
+        )
+
+
+def configure_model(
+    atom_features: typing.List[AtomFeature],
+    bond_features: typing.List[BondFeature],
+    n_gcn_layers: int,
+    n_gcn_hidden_features: int,
+    n_am1_layers: int,
+    n_am1_hidden_features: int,
+) -> ModelConfig:
+    return ModelConfig(
+        atom_features=atom_features,
+        bond_features=bond_features,
+        convolution=GCNConvolutionModule(
+            type="SAGEConv",
+            hidden_feats=[n_gcn_hidden_features] * n_gcn_layers,
+            activation=["ReLU"] * n_gcn_layers,
+        ),
+        readouts={
+            "mbis-charges": ReadoutModule(
+                pooling="atom",
+                forward=Sequential(
+                    hidden_feats=[n_am1_hidden_features] * n_am1_layers + [2],
+                    activation=["ReLU"] * n_am1_layers + ["Identity"],
+                ),
+                postprocess="charges",
+            )
+        },
+    )
+
+
+def configure_data() -> DataConfig:
+    return DataConfig(
+        training=Dataset(
+            sources=["../datasets/training.parquet"],
+            # The 'column' must match one of the label columns in the parquet
+            # table that was created during stage 000.
+            # The 'readout' should correspond to one of our model readout keys.
+            # Denominators: charge in e, dipole in e*bohr (0.04 e*bohr ~ 0.1 D).
+            targets=[
+                ReadoutTarget(
+                    column="mbis-charges",
+                    readout="mbis-charges",
+                    metric="rmse",
+                    denominator=0.02,
+                ),
+                DipoleTarget(
+                    metric="rmse",
+                    dipole_column="dipole",
+                    conformation_column="conformation",
+                    charge_label="mbis-charges",
+                    denominator=0.04,
+                ),
+            ],
+            batch_size=250,
+        ),
+        validation=Dataset(
+            sources=["../datasets/validation.parquet"],
+            targets=[
+                ReadoutTarget(
+                    column="mbis-charges",
+                    readout="mbis-charges",
+                    metric="rmse",
+                    denominator=0.02,
+                ),
+                DipoleTarget(
+                    metric="rmse",
+                    dipole_column="dipole",
+                    conformation_column="conformation",
+                    charge_label="mbis-charges",
+                    denominator=0.04,
+                ),
+            ],
+        ),
+        test=Dataset(
+            sources=["../datasets/testing.parquet"],
+            targets=[
+                ReadoutTarget(
+                    column="mbis-charges",
+                    readout="mbis-charges",
+                    metric="rmse",
+                    denominator=0.02,
+                ),
+                DipoleTarget(
+                    metric="rmse",
+                    dipole_column="dipole",
+                    conformation_column="conformation",
+                    charge_label="mbis-charges",
+                    denominator=0.04,
+                ),
+            ],
+        ),
+    )
+
+
+def configure_optimizer(lr: float) -> OptimizerConfig:
+    return OptimizerConfig(type="Adam", lr=lr)
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+    output_dir = pathlib.Path("001-train-charge-model-small-mols")
+
+    # make the custom ring feature available to the model config
+    register_atom_feature(AtomInRingOfSize)
+    print(_CUSTOM_ATOM_FEATURES)
+
+    # Configure our model, data sets, and optimizer.
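+    # Note: the custom feature is passed below as a plain dict (via
+    # dataclasses.asdict); its "type" field ("ringofsize") identifies the
+    # registered custom feature it should be rebuilt from.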
+    model_config = configure_model(
+        atom_features=[
+            AtomicElement(values=["H", "C", "N", "O", "F", "P", "S", "Cl", "Br"]),
+            AtomConnectivity(),
+            dataclasses.asdict(AtomInRingOfSize()),
+        ],
+        bond_features=[],
+        n_gcn_layers=5,
+        n_gcn_hidden_features=128,
+        n_am1_layers=2,
+        n_am1_hidden_features=64,
+    )
+    data_config = configure_data()
+
+    optimizer_config = configure_optimizer(0.001)
+
+    # Define the model and lightning data module that will contain the train, val,
+    # and test dataloaders if specified in ``data_config``.
+    config = Config(model=model_config, data=data_config, optimizer=optimizer_config)
+
+    model = DGLMoleculeLightningModel(config)
+    model.to_yaml("charge-dipole-v1.yaml")
+    print("Model", model)
+
+    # The 'cache_dir' will store the fully featurized molecules so we don't need
+    # to re-compute them each time we adjust a hyperparameter, for example.
+    data = DGLMoleculeDataModule(config, cache_dir=output_dir / "feature-cache")
+
+    # Define an MLFlow experiment to store the outputs of training this model.
+    # This will include the usual statistics as well as useful artifacts
+    # highlighting the model's weak spots.
+    logger = MLFlowLogger(
+        experiment_name="mbis-charge-dipole-model-small-mols-1000",
+        save_dir=str(output_dir / "mlruns"),
+        log_model="all",
+    )
+
+    # The MLFlow UI can be opened by running:
+    #
+    #    mlflow ui --backend-store-uri ./001-train-charge-model-small-mols/mlruns \
+    #              --default-artifact-root ./001-train-charge-model-small-mols/mlruns
+    #
+
+    # Train the model
+    n_epochs = 1000
+
+    n_gpus = 1 if torch.cuda.is_available() else 0
+    print(f"Using {n_gpus} GPUs")
+
+    # keep the checkpoint with the best validation loss
+    model_checkpoint = ModelCheckpoint(monitor='val/loss', dirpath=output_dir)
+    trainer = pl.Trainer(
+        # currently forced to CPU; switch the accelerator and uncomment
+        # 'devices' to use any GPUs detected above
+        accelerator='cpu',
+        # devices=n_gpus,
+        min_epochs=n_epochs,
+        max_epochs=n_epochs,
+        logger=logger,
+        log_every_n_steps=50,
+        callbacks=[model_checkpoint],
+    )
+
+    trainer.fit(model, datamodule=data)
+    trainer.test(model, datamodule=data)
+
+    print(model_checkpoint.best_model_path)
+
+
+if __name__ == "__main__":
+    main()