Commit
* add training and dataset prep scripts
* lint
Showing 5 changed files with 392 additions and 0 deletions.
@@ -0,0 +1,51 @@
# Collect stats on the number of molecules, the range of formal charges and the
# occurrences of elements in each train, val and test split of the dataset.
from rdkit import Chem
from rdkit.Chem import Descriptors
import deepchem as dc
import numpy as np

# keep explicit hydrogens when parsing SMILES
ps = Chem.SmilesParserParams()
ps.removeHs = False


def calculate_stats(dataset_name: str):
    formal_charges = {}
    molecular_weights = []
    elements = {}
    heavy_atom_count = []

    # load the dataset
    dataset = dc.data.DiskDataset(dataset_name)

    for smiles in dataset.ids:
        mol = Chem.MolFromSmiles(smiles, ps)
        charges = []
        for atom in mol.GetAtoms():
            charges.append(atom.GetFormalCharge())
            atomic_number = atom.GetAtomicNum()
            if atomic_number in elements:
                elements[atomic_number] += 1
            else:
                elements[atomic_number] = 1

        total_charge = sum(charges)
        if total_charge in formal_charges:
            formal_charges[total_charge] += 1
        else:
            formal_charges[total_charge] = 1

        molecular_weights.append(Descriptors.MolWt(mol))
        heavy_atom_count.append(Descriptors.HeavyAtomCount(mol))

    return formal_charges, molecular_weights, elements, heavy_atom_count


for dataset in ["maxmin-train", "maxmin-valid", "maxmin-test"]:
    charges, weights, atoms, heavy_atoms = calculate_stats(dataset_name=dataset)
    print(f"Running {dataset} number of molecules {len(weights)}")
    print("Total formal charges ", charges)
    print("Total elements", atoms)
    print(f"Average mol weight {np.mean(weights)} and std {np.std(weights)}")
    print(
        f"Average number of heavy atoms {np.mean(heavy_atoms)} and std {np.std(heavy_atoms)}"
    )
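A quick sanity check, as a hypothetical sketch rather than part of this commit: the per-atom bookkeeping in calculate_stats can be exercised on a single explicit-hydrogen SMILES with RDKit alone, with collections.Counter standing in for the manual dictionary updates.

from collections import Counter
from rdkit import Chem

ps = Chem.SmilesParserParams()
ps.removeHs = False  # keep explicit hydrogens, as in the script above

# acetate with explicit hydrogens: 3 H, 2 C, 2 O, net charge -1
mol = Chem.MolFromSmiles("[H]C([H])([H])C(=O)[O-]", ps)
elements = Counter(atom.GetAtomicNum() for atom in mol.GetAtoms())
total_charge = sum(atom.GetFormalCharge() for atom in mol.GetAtoms())

print(elements)      # Counter({1: 3, 6: 2, 8: 2})
print(total_charge)  # -1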
@@ -0,0 +1,72 @@
import h5py
import pyarrow
import pyarrow.parquet
from openff.units import unit
from collections import defaultdict
import deepchem as dc
import typing

# Set up the parquet datasets using the splits generated by deepchem.

# load up both source files
training_db = h5py.File("TrainingSet-v1.hdf5", "r")
valid_test_db = h5py.File("ValSet-v1.hdf5", "r")


def create_parquet_dataset(
    parquet_name: str,
    deep_chem_dataset: dc.data.DiskDataset,
    reference_datasets: typing.List[h5py.File],
):
    dataset_keys = deep_chem_dataset.X
    dataset_smiles = deep_chem_dataset.ids
    column_names = ["smiles", "conformation", "dipole", "mbis-charges"]
    results = defaultdict(list)
    # keep track of the total number of entries; each conformation is expanded
    # into a unique training point
    total_records = 0
    for key, smiles in zip(dataset_keys, dataset_smiles):
        for dataset in reference_datasets:
            if key in dataset:
                data_group = dataset[key]
                group_smiles = data_group["smiles"].asstr()[0]
                assert group_smiles == smiles
                charges = data_group["mbis-charges"][()]
                dipoles = data_group["dipole"][()]
                conformations = data_group["conformations"][()] * unit.angstrom
                # work out how many entries we have
                n_records = charges.shape[0]
                total_records += n_records
                for i in range(n_records):
                    results["smiles"].append(smiles)
                    results["mbis-charges"].append(charges[i])
                    results["dipole"].append(dipoles[i])
                    # flatten and convert to bohr before storing
                    results["conformation"].append(
                        conformations[i].m_as(unit.bohr).flatten()
                    )

    for key, values in results.items():
        assert len(values) == total_records, f"column {key} is missing records"
    columns = [results[label] for label in column_names]

    table = pyarrow.table(columns, names=column_names)
    pyarrow.parquet.write_table(table, parquet_name)


for file_name, dataset_name in [
    ("training.parquet", "maxmin-train"),
    ("validation.parquet", "maxmin-valid"),
    ("testing.parquet", "maxmin-test"),
]:
    print("creating parquet for ", dataset_name)
    dc_dataset = dc.data.DiskDataset(dataset_name)
    create_parquet_dataset(
        parquet_name=file_name,
        deep_chem_dataset=dc_dataset,
        reference_datasets=[training_db, valid_test_db],
    )

training_db.close()
valid_test_db.close()
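For downstream use, a hypothetical read-back sketch (not part of this commit): the column names follow the script above, and the flattened conformation can be restored to an (n_atoms, 3) array by taking n_atoms from the per-atom charge vector.

import numpy as np
import pyarrow.parquet

table = pyarrow.parquet.read_table("training.parquet")
row = table.slice(0, 1).to_pylist()[0]  # first record as a plain dict

n_atoms = len(row["mbis-charges"])
conformation_bohr = np.array(row["conformation"]).reshape(n_atoms, 3)
print(row["smiles"], conformation_bohr.shape)

Storing each conformation flattened and in bohr keeps the table schema simple (one list column per record) at the cost of this small reshape on load.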
@@ -0,0 +1,36 @@
# Try splitting the entire collection of data using deepchem splitters.
import h5py
import deepchem as dc
import numpy as np

dataset_keys = []
smiles_ids = []
training_set = h5py.File("TrainingSet-v1.hdf5", "r")
for key, group in training_set.items():
    smiles_ids.append(group["smiles"].asstr()[0])
    # use the key to quickly split the datasets later
    dataset_keys.append(key)
training_set.close()

# val_set = h5py.File('ValSet-v1.hdf5', 'r')
# for key, group in val_set.items():
#     smiles_ids.append(group['smiles'].asstr()[0])
#     dataset_keys.append(key)

# val_set.close()


print(f"The total number of unique molecules {len(smiles_ids)}")
print("Running MaxMin Splitter ...")

xs = np.array(dataset_keys)

total_dataset = dc.data.DiskDataset.from_numpy(X=xs, ids=smiles_ids)

max_min_split = dc.splits.MaxMinSplitter()
train, validation, test = max_min_split.train_valid_test_split(
    total_dataset,
    train_dir="maxmin-train",
    valid_dir="maxmin-valid",
    test_dir="maxmin-test",
)
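Once the splitter has run, the three split directories can be reloaded directly; this hypothetical follow-up (not part of this commit) mirrors how the stats and parquet scripts above open them.

import deepchem as dc

for name in ["maxmin-train", "maxmin-valid", "maxmin-test"]:
    split = dc.data.DiskDataset(name)
    print(name, len(split))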