diff --git a/naglmbis/__init__.py b/naglmbis/__init__.py
index 060f7cb..edf4351 100644
--- a/naglmbis/__init__.py
+++ b/naglmbis/__init__.py
@@ -2,6 +2,7 @@ naglmbis
 Models built with NAGL to predict MBIS properties.
 """
+
 from . import _version
 
 __version__ = _version.get_versions()["version"]
 
diff --git a/scripts/dataset/analysis_train_val_test_split.py b/scripts/dataset/analysis_train_val_test_split.py
new file mode 100644
index 0000000..453dc02
--- /dev/null
+++ b/scripts/dataset/analysis_train_val_test_split.py
@@ -0,0 +1,51 @@
+# collect stats on the number of molecules, the range of total formal charges and the occurrences of each element in the train, val and test splits of the dataset
+from rdkit import Chem
+from rdkit.Chem import Descriptors
+import deepchem as dc
+import numpy as np
+
+ps = Chem.SmilesParserParams()
+ps.removeHs = False
+
+
+def calculate_stats(dataset_name: str):
+    formal_charges = {}
+    molecular_weights = []
+    elements = {}
+    heavy_atom_count = []
+
+    # load the dataset
+    dataset = dc.data.DiskDataset(dataset_name)
+
+    for smiles in dataset.ids:
+        mol = Chem.MolFromSmiles(smiles, ps)
+        charges = []
+        for atom in mol.GetAtoms():
+            charges.append(atom.GetFormalCharge())
+            atomic_number = atom.GetAtomicNum()
+            if atomic_number in elements:
+                elements[atomic_number] += 1
+            else:
+                elements[atomic_number] = 1
+
+        total_charge = sum(charges)
+        if total_charge in formal_charges:
+            formal_charges[total_charge] += 1
+        else:
+            formal_charges[total_charge] = 1
+
+        molecular_weights.append(Descriptors.MolWt(mol))
+        heavy_atom_count.append(Descriptors.HeavyAtomCount(mol))
+
+    return formal_charges, molecular_weights, elements, heavy_atom_count
+
+
+for dataset in ["maxmin-train", "maxmin-valid", "maxmin-test"]:
+    charges, weights, atoms, heavy_atoms = calculate_stats(dataset_name=dataset)
+    print(f"Dataset {dataset}: number of molecules {len(weights)}")
+    print("Total formal charges ", charges)
+    print("Total elements", atoms)
+    print(f"Average mol weight {np.mean(weights)} and std {np.std(weights)}")
+    print(
+        f"Average number of heavy atoms {np.mean(heavy_atoms)} and std {np.std(heavy_atoms)}"
+    )
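
The element histogram printed above is keyed by atomic number. A minimal sketch (illustrative, not part of the patch; format_element_counts is a hypothetical helper) of mapping those keys to element symbols with RDKit's periodic table:

    # Convert the element histogram returned by calculate_stats from
    # atomic numbers to element symbols for easier reading.
    from rdkit import Chem

    def format_element_counts(elements: dict) -> dict:
        periodic_table = Chem.GetPeriodicTable()
        return {periodic_table.GetElementSymbol(z): n for z, n in elements.items()}

    # e.g. {6: 12, 1: 26, 8: 2} -> {"C": 12, "H": 26, "O": 2}
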
diff --git a/scripts/dataset/setup_labeled_data.py b/scripts/dataset/setup_labeled_data.py
new file mode 100644
index 0000000..d4b7469
--- /dev/null
+++ b/scripts/dataset/setup_labeled_data.py
@@ -0,0 +1,71 @@
+import h5py
+import pyarrow
+import pyarrow.parquet
+from openff.units import unit
+from collections import defaultdict
+import deepchem as dc
+import typing
+
+# set up the parquet datasets using the splits generated by deepchem
+
+
+# load up both files
+training_db = h5py.File("TrainingSet-v1.hdf5", "r")
+valid_test_db = h5py.File("ValSet-v1.hdf5", "r")
+
+
+def create_parquet_dataset(
+    parquet_name: str,
+    deep_chem_dataset: dc.data.DiskDataset,
+    reference_datasets: typing.List[h5py.File],
+):
+    dataset_keys = deep_chem_dataset.X
+    dataset_smiles = deep_chem_dataset.ids
+    column_names = ["smiles", "conformation", "dipole", "mbis-charges"]
+    results = defaultdict(list)
+    # keep track of the total number of entries; each conformation is expanded into a unique training point
+    total_records = 0
+    for key, smiles in zip(dataset_keys, dataset_smiles):
+        for dataset in reference_datasets:
+            if key in dataset:
+                data_group = dataset[key]
+                group_smiles = data_group["smiles"].asstr()[0]
+                assert group_smiles == smiles
+                charges = data_group["mbis-charges"][()]
+                dipoles = data_group["dipole"][()]
+                conformations = data_group["conformations"][()] * unit.angstrom
+                # work out how many entries this molecule contributes
+                n_records = charges.shape[0]
+                total_records += n_records
+                for i in range(n_records):
+                    results["smiles"].append(smiles)
+                    results["mbis-charges"].append(charges[i])
+                    results["dipole"].append(dipoles[i])
+                    # make sure to store the conformation in bohr
+                    results["conformation"].append(
+                        conformations[i].m_as(unit.bohr).flatten()
+                    )
+
+    for key, values in results.items():
+        assert len(values) == total_records, key
+    columns = [results[label] for label in column_names]
+
+    table = pyarrow.table(columns, column_names)
+    pyarrow.parquet.write_table(table, parquet_name)
+
+
+for file_name, dataset_name in [
+    ("training.parquet", "maxmin-train"),
+    ("validation.parquet", "maxmin-valid"),
+    ("testing.parquet", "maxmin-test"),
+]:
+    print("creating parquet for ", dataset_name)
+    dc_dataset = dc.data.DiskDataset(dataset_name)
+    create_parquet_dataset(
+        parquet_name=file_name,
+        deep_chem_dataset=dc_dataset,
+        reference_datasets=[training_db, valid_test_db],
+    )
+
+training_db.close()
+valid_test_db.close()
diff --git a/scripts/dataset/split_by_deepchem.py b/scripts/dataset/split_by_deepchem.py
new file mode 100644
index 0000000..c023451
--- /dev/null
+++ b/scripts/dataset/split_by_deepchem.py
@@ -0,0 +1,36 @@
+# try splitting the entire collection of data using the deepchem splitters
+import h5py
+import deepchem as dc
+import numpy as np
+
+dataset_keys = []
+smiles_ids = []
+training_set = h5py.File("TrainingSet-v1.hdf5", "r")
+for key, group in training_set.items():
+    smiles_ids.append(group["smiles"].asstr()[0])
+    # use the key to quickly split the datasets later
+    dataset_keys.append(key)
+training_set.close()
+
+# val_set = h5py.File('ValSet-v1.hdf5', 'r')
+# for key, group in val_set.items():
+#     smiles_ids.append(group['smiles'].asstr()[0])
+#     dataset_keys.append(key)
+
+# val_set.close()
+
+
+print(f"The total number of unique molecules: {len(smiles_ids)}")
+print("Running MaxMin splitter ...")
+
+xs = np.array(dataset_keys)
+
+total_dataset = dc.data.DiskDataset.from_numpy(X=xs, ids=smiles_ids)
+
+max_min_split = dc.splits.MaxMinSplitter()
+train, validation, test = max_min_split.train_valid_test_split(
+    total_dataset,
+    train_dir="maxmin-train",
+    valid_dir="maxmin-valid",
+    test_dir="maxmin-test",
+)
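
Before training, the parquet tables written by setup_labeled_data.py can be spot-checked. A minimal sketch (illustrative, assuming the files sit in the working directory):

    # Each row of the table is a single conformation, so num_rows counts
    # training points rather than unique molecules.
    import pyarrow.parquet

    table = pyarrow.parquet.read_table("training.parquet")
    print(table.schema)    # smiles, conformation, dipole, mbis-charges
    print(table.num_rows)
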
description="The size of the ring we want to check membership of", + ) + + def __len__(self): + return len(self.ring_sizes) + + def __call__(self, molecule: Chem.Mol) -> torch.Tensor: + ring_info: Chem.RingInfo = molecule.GetRingInfo() + + return torch.vstack( + [ + torch.Tensor( + [ + int(ring_info.IsAtomInRingOfSize(atom.GetIdx(), ring_size)) + for ring_size in self.ring_sizes + ] + ) + for atom in molecule.GetAtoms() + ] + ) + + +def configure_model( + atom_features: typing.List[AtomFeature], + bond_features: typing.List[BondFeature], + n_gcn_layers: int, + n_gcn_hidden_features: int, + n_am1_layers: int, + n_am1_hidden_features: int, +) -> ModelConfig: + return ModelConfig( + atom_features=atom_features, + bond_features=bond_features, + convolution=GCNConvolutionModule( + type="SAGEConv", + hidden_feats=[n_gcn_hidden_features] * n_gcn_layers, + activation=["ReLU"] * n_gcn_layers, + ), + readouts={ + "mbis-charges": ReadoutModule( + pooling="atom", + forward=Sequential( + hidden_feats=[n_am1_hidden_features] * n_am1_layers + [2], + activation=["ReLU"] * n_am1_layers + ["Identity"], + ), + postprocess="charges", + ) + }, + ) + + +def configure_data() -> DataConfig: + return DataConfig( + training=Dataset( + sources=["../datasets/training.parquet"], + # The 'column' must match one of the label columns in the parquet + # table that was create during stage 000. + # The 'readout' column should correspond to one our or model readout + # keys. + # denom for charge in e and dipole in e*bohr 0.1D~ + targets=[ + ReadoutTarget( + column="mbis-charges", + readout="mbis-charges", + metric="rmse", + denominator=0.02, + ), + DipoleTarget( + metric="rmse", + dipole_column="dipole", + conformation_column="conformation", + charge_label="mbis-charges", + denominator=0.04, + ), + ], + batch_size=250, + ), + validation=Dataset( + sources=["../datasets/validation.parquet"], + targets=[ + ReadoutTarget( + column="mbis-charges", + readout="mbis-charges", + metric="rmse", + denominator=0.02, + ), + DipoleTarget( + metric="rmse", + dipole_column="dipole", + conformation_column="conformation", + charge_label="mbis-charges", + denominator=0.04, + ), + ], + ), + test=Dataset( + sources=["../datasets/testing.parquet"], + targets=[ + ReadoutTarget( + column="mbis-charges", + readout="mbis-charges", + metric="rmse", + denominator=0.02, + ), + DipoleTarget( + metric="rmse", + dipole_column="dipole", + conformation_column="conformation", + charge_label="mbis-charges", + denominator=0.04, + ), + ], + ), + ) + + +def configure_optimizer(lr: float) -> OptimizerConfig: + return OptimizerConfig(type="Adam", lr=lr) + + +def main(): + logging.basicConfig(level=logging.INFO) + output_dir = pathlib.Path("001-train-charge-model-small-mols") + + register_atom_feature(AtomInRingOfSize) + print(_CUSTOM_ATOM_FEATURES) + # Configure our model, data sets, and optimizer. + model_config = configure_model( + atom_features=[ + AtomicElement(values=["H", "C", "N", "O", "F", "P", "S", "Cl", "Br"]), + AtomConnectivity(), + dataclasses.asdict(AtomInRingOfSize()), + ], + bond_features=[], + n_gcn_layers=5, + n_gcn_hidden_features=128, + n_am1_layers=2, + n_am1_hidden_features=64, + ) + data_config = configure_data() + + optimizer_config = configure_optimizer(0.001) + + # Define the model and lightning data module that will contain the train, val, + # and test dataloaders if specified in ``data_config``. 
+
+
+def configure_model(
+    atom_features: typing.List[AtomFeature],
+    bond_features: typing.List[BondFeature],
+    n_gcn_layers: int,
+    n_gcn_hidden_features: int,
+    n_am1_layers: int,
+    n_am1_hidden_features: int,
+) -> ModelConfig:
+    return ModelConfig(
+        atom_features=atom_features,
+        bond_features=bond_features,
+        convolution=GCNConvolutionModule(
+            type="SAGEConv",
+            hidden_feats=[n_gcn_hidden_features] * n_gcn_layers,
+            activation=["ReLU"] * n_gcn_layers,
+        ),
+        readouts={
+            "mbis-charges": ReadoutModule(
+                pooling="atom",
+                forward=Sequential(
+                    hidden_feats=[n_am1_hidden_features] * n_am1_layers + [2],
+                    activation=["ReLU"] * n_am1_layers + ["Identity"],
+                ),
+                postprocess="charges",
+            )
+        },
+    )
+
+
+def configure_data() -> DataConfig:
+    return DataConfig(
+        training=Dataset(
+            sources=["../datasets/training.parquet"],
+            # The 'column' must match one of the label columns in the parquet
+            # table that was created during stage 000.
+            # The 'readout' should correspond to one of our model readout
+            # keys.
+            # denominators: charge in e, dipole in e*bohr (0.04 e*bohr ~ 0.1 D)
+            targets=[
+                ReadoutTarget(
+                    column="mbis-charges",
+                    readout="mbis-charges",
+                    metric="rmse",
+                    denominator=0.02,
+                ),
+                DipoleTarget(
+                    metric="rmse",
+                    dipole_column="dipole",
+                    conformation_column="conformation",
+                    charge_label="mbis-charges",
+                    denominator=0.04,
+                ),
+            ],
+            batch_size=250,
+        ),
+        validation=Dataset(
+            sources=["../datasets/validation.parquet"],
+            targets=[
+                ReadoutTarget(
+                    column="mbis-charges",
+                    readout="mbis-charges",
+                    metric="rmse",
+                    denominator=0.02,
+                ),
+                DipoleTarget(
+                    metric="rmse",
+                    dipole_column="dipole",
+                    conformation_column="conformation",
+                    charge_label="mbis-charges",
+                    denominator=0.04,
+                ),
+            ],
+        ),
+        test=Dataset(
+            sources=["../datasets/testing.parquet"],
+            targets=[
+                ReadoutTarget(
+                    column="mbis-charges",
+                    readout="mbis-charges",
+                    metric="rmse",
+                    denominator=0.02,
+                ),
+                DipoleTarget(
+                    metric="rmse",
+                    dipole_column="dipole",
+                    conformation_column="conformation",
+                    charge_label="mbis-charges",
+                    denominator=0.04,
+                ),
+            ],
+        ),
+    )
+
+
+def configure_optimizer(lr: float) -> OptimizerConfig:
+    return OptimizerConfig(type="Adam", lr=lr)
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+    output_dir = pathlib.Path("001-train-charge-model-small-mols")
+
+    register_atom_feature(AtomInRingOfSize)
+    print(_CUSTOM_ATOM_FEATURES)
+    # Configure our model, data sets, and optimizer.
+    model_config = configure_model(
+        atom_features=[
+            AtomicElement(values=["H", "C", "N", "O", "F", "P", "S", "Cl", "Br"]),
+            AtomConnectivity(),
+            dataclasses.asdict(AtomInRingOfSize()),
+        ],
+        bond_features=[],
+        n_gcn_layers=5,
+        n_gcn_hidden_features=128,
+        n_am1_layers=2,
+        n_am1_hidden_features=64,
+    )
+    data_config = configure_data()
+
+    optimizer_config = configure_optimizer(0.001)
+
+    # Define the model and lightning data module that will contain the train, val,
+    # and test dataloaders if specified in ``data_config``.
+    config = Config(model=model_config, data=data_config, optimizer=optimizer_config)
+
+    model = DGLMoleculeLightningModel(config)
+    model.to_yaml("charge-dipole-v1.yaml")
+    print("Model", model)
+
+    # The 'cache_dir' will store the fully featurized molecules so we don't need to
+    # re-compute them each time we adjust a hyperparameter, for example.
+    data = DGLMoleculeDataModule(config, cache_dir=output_dir / "feature-cache")
+
+    # Define an MLFlow experiment to store the outputs of training this model. This
+    # will include the usual statistics as well as useful artifacts highlighting
+    # the model's weak spots.
+    logger = MLFlowLogger(
+        experiment_name="mbis-charge-dipole-model-small-mols-1000",
+        save_dir=str(output_dir / "mlruns"),
+        log_model="all",
+    )
+
+    # The MLFlow UI can be opened by running:
+    #
+    # mlflow ui --backend-store-uri ./001-train-charge-model-small-mols/mlruns \
+    #     --default-artifact-root ./001-train-charge-model-small-mols/mlruns
+    #
+
+    # Train the model
+    n_epochs = 1000
+
+    n_gpus = 1 if torch.cuda.is_available() else 0
+    print(f"Using {n_gpus} GPUs")
+
+    model_checkpoint = ModelCheckpoint(
+        monitor="val/loss", dirpath=output_dir
+    )
+    trainer = pl.Trainer(
+        accelerator="cpu",
+        # devices=n_gpus,
+        min_epochs=n_epochs,
+        max_epochs=n_epochs,
+        logger=logger,
+        log_every_n_steps=50,
+        callbacks=[model_checkpoint],
+    )
+
+    trainer.fit(model, datamodule=data)
+    trainer.test(model, datamodule=data)
+
+    print(model_checkpoint.best_model_path)
+
+
+if __name__ == "__main__":
+    main()
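
Once training finishes, the checkpoint selected by ModelCheckpoint can be reloaded for inference. A minimal sketch (illustrative; whether extra constructor arguments are needed depends on how DGLMoleculeLightningModel stores its hyperparameters):

    # Restore the weights from the best checkpoint written during fitting.
    best_model = DGLMoleculeLightningModel.load_from_checkpoint(
        model_checkpoint.best_model_path
    )
    best_model.eval()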