Training scripts (#15)
* add training and dataset prep scripts

* lint
jthorton authored Jun 28, 2024
1 parent f54a306 commit 9d86ac0
Showing 5 changed files with 392 additions and 0 deletions.
1 change: 1 addition & 0 deletions naglmbis/__init__.py
@@ -2,6 +2,7 @@
naglmbis
Models built with NAGL to predict MBIS properties.
"""

from . import _version

__version__ = _version.get_versions()["version"]
51 changes: 51 additions & 0 deletions scripts/dataset/analysis_train_val_test_split.py
@@ -0,0 +1,51 @@
# collect stats on the number of molecules, the range of formal charges and the
# occurrences of elements in each of the train, val and test splits of the dataset
from rdkit import Chem
from rdkit.Chem import Descriptors
import deepchem as dc
import numpy as np

ps = Chem.SmilesParserParams()
ps.removeHs = False


def calculate_stats(dataset_name: str):
    formal_charges = {}
    molecular_weights = []
    elements = {}
    heavy_atom_count = []

    # load the dataset
    dataset = dc.data.DiskDataset(dataset_name)

    for smiles in dataset.ids:
        mol = Chem.MolFromSmiles(smiles, ps)
        charges = []
        for atom in mol.GetAtoms():
            charges.append(atom.GetFormalCharge())
            atomic_number = atom.GetAtomicNum()
            if atomic_number in elements:
                elements[atomic_number] += 1
            else:
                elements[atomic_number] = 1

        total_charge = sum(charges)
        if total_charge in formal_charges:
            formal_charges[total_charge] += 1
        else:
            formal_charges[total_charge] = 1

        molecular_weights.append(Descriptors.MolWt(mol))
        heavy_atom_count.append(Descriptors.HeavyAtomCount(mol))

    return formal_charges, molecular_weights, elements, heavy_atom_count


for dataset in ["maxmin-train", "maxmin-valid", "maxmin-test"]:
    charges, weights, atoms, heavy_atoms = calculate_stats(dataset_name=dataset)
    print(f"Running {dataset} number of molecules {len(weights)}")
    print("Total formal charges ", charges)
    print("Total elements", atoms)
    print(f"Average mol weight {np.mean(weights)} and std {np.std(weights)}")
    print(
        f"Average number of heavy atoms {np.mean(heavy_atoms)} and std {np.std(heavy_atoms)}"
    )
72 changes: 72 additions & 0 deletions scripts/dataset/setup_labeled_data.py
@@ -0,0 +1,72 @@
import h5py
import pyarrow
import pyarrow.parquet
from openff.units import unit
from collections import defaultdict
import deepchem as dc
import typing

# set up the parquet datasets using the splits generated by deepchem


# load up both files
training_db = h5py.File("TrainingSet-v1.hdf5", "r")
valid_test_db = h5py.File("ValSet-v1.hdf5", "r")


def create_parquet_dataset(
    parquet_name: str,
    deep_chem_dataset: dc.data.DiskDataset,
    reference_datasets: typing.List[h5py.File],
):
    dataset_keys = deep_chem_dataset.X
    dataset_smiles = deep_chem_dataset.ids
    column_names = ["smiles", "conformation", "dipole", "mbis-charges"]
    results = defaultdict(list)
    # keep track of the number of total entries; each conformation is expanded
    # into a unique training point
    total_records = 0
    for key, smiles in zip(dataset_keys, dataset_smiles):
        for dataset in reference_datasets:
            if key in dataset:
                data_group = dataset[key]
                group_smiles = data_group["smiles"].asstr()[0]
                assert group_smiles == smiles
                charges = data_group["mbis-charges"][()]
                dipoles = data_group["dipole"][()]
                conformations = data_group["conformations"][()] * unit.angstrom
                # work out how many entries we have
                n_records = charges.shape[0]
                total_records += n_records
                for i in range(n_records):
                    results["smiles"].append(smiles)
                    results["mbis-charges"].append(charges[i])
                    results["dipole"].append(dipoles[i])
                    # make sure to store the conformation in bohr
                    results["conformation"].append(
                        conformations[i].m_as(unit.bohr).flatten()
                    )

    for key, values in results.items():
        assert len(values) == total_records, key
    columns = [results[label] for label in column_names]

    table = pyarrow.table(columns, column_names)
    pyarrow.parquet.write_table(table, parquet_name)


for file_name, dataset_name in [
    ("training.parquet", "maxmin-train"),
    ("validation.parquet", "maxmin-valid"),
    ("testing.parquet", "maxmin-test"),
]:
    print("creating parquet for ", dataset_name)
    dc_dataset = dc.data.DiskDataset(dataset_name)
    create_parquet_dataset(
        parquet_name=file_name,
        deep_chem_dataset=dc_dataset,
        reference_datasets=[training_db, valid_test_db],
    )

training_db.close()
valid_test_db.close()
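A quick sanity check (a sketch, not part of the commit) that one of the written parquet files round-trips through pyarrow with the expected columns:

import pyarrow.parquet

# read the table back and confirm the schema matches the column names above
table = pyarrow.parquet.read_table("training.parquet")
print(table.column_names)  # ['smiles', 'conformation', 'dipole', 'mbis-charges']
print(f"total records (one per conformation): {table.num_rows}")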
36 changes: 36 additions & 0 deletions scripts/dataset/split_by_deepchem.py
@@ -0,0 +1,36 @@
# try splitting the entire collection of data using the deepchem splitters
import h5py
import deepchem as dc
import numpy as np

dataset_keys = []
smiles_ids = []
training_set = h5py.File("TrainingSet-v1.hdf5", "r")
for key, group in training_set.items():
    smiles_ids.append(group["smiles"].asstr()[0])
    # use the key to quickly split the datasets later
    dataset_keys.append(key)
training_set.close()

# val_set = h5py.File('ValSet-v1.hdf5', 'r')
# for key, group in val_set.items():
#     smiles_ids.append(group['smiles'].asstr()[0])
#     dataset_keys.append(key)

# val_set.close()


print(f"The total number of unique molecules: {len(smiles_ids)}")
print("Running MaxMin Splitter ...")

xs = np.array(dataset_keys)

total_dataset = dc.data.DiskDataset.from_numpy(X=xs, ids=smiles_ids)

max_min_split = dc.splits.MaxMinSplitter()
train, validation, test = max_min_split.train_valid_test_split(
    total_dataset,
    train_dir="maxmin-train",
    valid_dir="maxmin-valid",
    test_dir="maxmin-test",
)
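For reference, train_valid_test_split falls back to DeepChem's default split fractions (80/10/10) when none are given; a sketch passing them explicitly, should a different ratio ever be needed:

train, validation, test = max_min_split.train_valid_test_split(
    total_dataset,
    frac_train=0.8,  # the DeepChem defaults, shown explicitly
    frac_valid=0.1,
    frac_test=0.1,
    train_dir="maxmin-train",
    valid_dir="maxmin-valid",
    test_dir="maxmin-test",
)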