Commit: lint
jthorton committed Jun 28, 2024
1 parent 112b15e commit c535767
Showing 5 changed files with 48 additions and 27 deletions.
1 change: 1 addition & 0 deletions naglmbis/__init__.py
@@ -2,6 +2,7 @@
naglmbis
Models built with NAGL to predict MBIS properties.
"""

from . import _version

__version__ = _version.get_versions()["version"]
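The only change here is the blank line after the module docstring; the package version is still resolved through versioneer at import time. A minimal check, purely illustrative and not part of this commit:

import naglmbis

# __version__ is filled in from _version.get_versions()["version"] on import
print(naglmbis.__version__)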
19 changes: 10 additions & 9 deletions scripts/dataset/analysis_train_val_test_split.py
@@ -16,7 +16,7 @@ def calculate_stats(dataset_name: str):

    # load the dataset
    dataset = dc.data.DiskDataset(dataset_name)

    for smiles in dataset.ids:
        mol = Chem.MolFromSmiles(smiles, ps)
        charges = []
@@ -27,7 +27,7 @@ def calculate_stats(dataset_name: str):
                elements[atomic_number] += 1
            else:
                elements[atomic_number] = 1

        total_charge = sum(charges)
        if total_charge in formal_charges:
            formal_charges[total_charge] += 1
@@ -40,11 +40,12 @@ def calculate_stats(dataset_name: str):
    return formal_charges, molecular_weights, elements, heavy_atom_count


-for dataset in ['maxmin-train', 'maxmin-valid', 'maxmin-test']:
+for dataset in ["maxmin-train", "maxmin-valid", "maxmin-test"]:
    charges, weights, atoms, heavy_atoms = calculate_stats(dataset_name=dataset)
-    print(f'Running {dataset} number of molecules {len(weights)}')
-    print('Total formal charges ', charges)
-    print('Total elements', atoms)
-    print(f'Average mol weight {np.mean(weights)} and std {np.std(weights)}')
-    print(f'Average number of heavy atoms {np.mean(heavy_atoms)} and std {np.std(heavy_atoms)}')
-
+    print(f"Running {dataset} number of molecules {len(weights)}")
+    print("Total formal charges ", charges)
+    print("Total elements", atoms)
+    print(f"Average mol weight {np.mean(weights)} and std {np.std(weights)}")
+    print(
+        f"Average number of heavy atoms {np.mean(heavy_atoms)} and std {np.std(heavy_atoms)}"
+    )
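The hunks above only touch the summary loop; calculate_stats itself fills the formal_charges, elements, molecular_weights and heavy_atom_count accumulators from RDKit molecules. As a rough, self-contained illustration of the same kind of per-molecule statistics (not the script's actual code; the SMILES list and the use of Descriptors.MolWt are assumptions):

from rdkit import Chem
from rdkit.Chem import Descriptors

smiles_list = ["CCO", "c1ccccc1", "CC(=O)[O-]"]  # placeholder molecules
elements, formal_charges = {}, {}
weights, heavy_atoms = [], []
for smiles in smiles_list:
    mol = Chem.MolFromSmiles(smiles)
    # count element occurrences across the dataset
    for atom in mol.GetAtoms():
        z = atom.GetAtomicNum()
        elements[z] = elements.get(z, 0) + 1
    # tally net formal charges and size measures per molecule
    charge = Chem.GetFormalCharge(mol)
    formal_charges[charge] = formal_charges.get(charge, 0) + 1
    weights.append(Descriptors.MolWt(mol))
    heavy_atoms.append(mol.GetNumHeavyAtoms())
print(formal_charges, elements, weights, heavy_atoms)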
31 changes: 22 additions & 9 deletions scripts/dataset/setup_labeled_data.py
@@ -9,12 +9,16 @@
# setup the parquet datasets using the splits generated by deepchem



# load up both files
training_db = h5py.File("TrainingSet-v1.hdf5", "r")
-valid_test_db = h5py.File('ValSet-v1.hdf5', 'r')
+valid_test_db = h5py.File("ValSet-v1.hdf5", "r")


-def create_parquet_dataset(parquet_name: str, deep_chem_dataset: dc.data.DiskDataset, reference_datasets: typing.List[h5py.File]):
+def create_parquet_dataset(
+    parquet_name: str,
+    deep_chem_dataset: dc.data.DiskDataset,
+    reference_datasets: typing.List[h5py.File],
+):
    dataset_keys = deep_chem_dataset.X
    dataset_smiles = deep_chem_dataset.ids
    coloumn_names = ["smiles", "conformation", "dipole", "mbis-charges"]
@@ -39,8 +43,9 @@ def create_parquet_dataset(parquet_name: str, deep_chem_dataset: dc.data.DiskDat
                results["mbis-charges"].append(charges[i])
                results["dipole"].append(dipoles[i])
                # make to store in bohr
-                results["conformation"].append(conformations[i].m_as(unit.bohr).flatten())
-
+                results["conformation"].append(
+                    conformations[i].m_as(unit.bohr).flatten()
+                )

    for key, values in results.items():
        assert len(values) == total_records, print(key)
@@ -50,10 +55,18 @@ def create_parquet_dataset(parquet_name: str, deep_chem_dataset: dc.data.DiskDat
    pyarrow.parquet.write_table(table, parquet_name)


-for file_name, dataset_name in [('training.parquet', 'maxmin-train'), ('validation.parquet', 'maxmin-valid'), ('testing.parquet', 'maxmin-test')]:
-    print('creating parquet for ', dataset_name)
+for file_name, dataset_name in [
+    ("training.parquet", "maxmin-train"),
+    ("validation.parquet", "maxmin-valid"),
+    ("testing.parquet", "maxmin-test"),
+]:
+    print("creating parquet for ", dataset_name)
    dc_dataset = dc.data.DiskDataset(dataset_name)
-    create_parquet_dataset(parquet_name=file_name, deep_chem_dataset=dc_dataset, reference_datasets=[training_db, valid_test_db])
+    create_parquet_dataset(
+        parquet_name=file_name,
+        deep_chem_dataset=dc_dataset,
+        reference_datasets=[training_db, valid_test_db],
+    )

training_db.close()
-valid_test_db.close()
+valid_test_db.close()
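Each split ends up as a parquet file with the columns declared in create_parquet_dataset ("smiles", "conformation", "dipole", "mbis-charges"). A quick way to sanity-check one of the outputs, shown only as an illustrative sketch assuming pyarrow is available:

import pyarrow.parquet as pq

# reopen one of the files written above and confirm its schema
table = pq.read_table("training.parquet")
print(table.column_names)  # expected: smiles, conformation, dipole, mbis-charges
print(table.num_rows)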
16 changes: 10 additions & 6 deletions scripts/dataset/split_by_deepchem.py
@@ -5,9 +5,9 @@

dataset_keys = []
smiles_ids = []
-training_set = h5py.File('TrainingSet-v1.hdf5', 'r')
+training_set = h5py.File("TrainingSet-v1.hdf5", "r")
for key, group in training_set.items():
-    smiles_ids.append(group['smiles'].asstr()[0])
+    smiles_ids.append(group["smiles"].asstr()[0])
    # use the key to quickly split the datasets later
    dataset_keys.append(key)
training_set.close()
@@ -20,13 +20,17 @@

# val_set.close()


-print(f'The total number of unique molecules {len(smiles_ids)}')
-print('Running MaxMin Splitter ...')
+print(f"The total number of unique molecules {len(smiles_ids)}")
+print("Running MaxMin Splitter ...")

xs = np.array(dataset_keys)

total_dataset = dc.data.DiskDataset.from_numpy(X=xs, ids=smiles_ids)

max_min_split = dc.splits.MaxMinSplitter()
-train, validation, test = max_min_split.train_valid_test_split(total_dataset, train_dir='maxmin-train', valid_dir='maxmin-valid', test_dir='maxmin-test')
-
+train, validation, test = max_min_split.train_valid_test_split(
+    total_dataset,
+    train_dir="maxmin-train",
+    valid_dir="maxmin-valid",
+    test_dir="maxmin-test",
+)
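train_valid_test_split writes each split to the named directory as a DeepChem DiskDataset, which is how the analysis and parquet scripts above reopen them. An illustrative reload check, not part of the script:

import deepchem as dc

# each directory produced by the MaxMin splitter can be reopened directly
for name in ("maxmin-train", "maxmin-valid", "maxmin-test"):
    split = dc.data.DiskDataset(name)
    print(name, len(split.ids), "molecules")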
8 changes: 5 additions & 3 deletions scripts/training/train_model.py
@@ -209,15 +209,17 @@ def main():
    n_gpus = 0 if not torch.cuda.is_available() else 1
    print(f"Using {n_gpus} GPUs")

-    model_checkpoint = ModelCheckpoint(monitor='val/loss', dirpath=output_dir.joinpath(''))
+    model_checkpoint = ModelCheckpoint(
+        monitor="val/loss", dirpath=output_dir.joinpath("")
+    )
    trainer = pl.Trainer(
-        accelerator='cpu',
+        accelerator="cpu",
        # devices=n_gpus,
        min_epochs=n_epochs,
        max_epochs=n_epochs,
        logger=logger,
        log_every_n_steps=50,
-        callbacks=[model_checkpoint]
+        callbacks=[model_checkpoint],
    )

    trainer.fit(model, datamodule=data)
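Since ModelCheckpoint now monitors val/loss and writes into output_dir, the best checkpoint can be picked up after trainer.fit returns. An illustrative sketch; NaglModel is a placeholder name, not a class taken from train_model.py:

# after trainer.fit(model, datamodule=data) finishes, the callback records
# the path of the checkpoint with the lowest val/loss
best_path = model_checkpoint.best_model_path
print(f"best checkpoint: {best_path}")

# reloading would go through the LightningModule subclass defined in the script,
# e.g. model = NaglModel.load_from_checkpoint(best_path)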
