diff --git a/docs/source/api/utils.mols.rst b/docs/source/api/utils.mols.rst index 5596977e..e877a259 100644 --- a/docs/source/api/utils.mols.rst +++ b/docs/source/api/utils.mols.rst @@ -49,6 +49,9 @@ three common graph constructions: dgllife.utils.k_nearest_neighbors dgllife.utils.mol_to_nearest_neighbor_graph dgllife.utils.smiles_to_nearest_neighbor_graph + dgllife.utils.ToGraph + dgllife.utils.MolToBigraph + dgllife.utils.SMILESToBigraph Featurization for Molecules --------------------------- diff --git a/examples/property_prediction/MTL/main.py b/examples/property_prediction/MTL/main.py index bd697a73..487c104a 100644 --- a/examples/property_prediction/MTL/main.py +++ b/examples/property_prediction/MTL/main.py @@ -31,7 +31,7 @@ from argparse import ArgumentParser from dgllife.data import MoleculeCSVDataset - from dgllife.utils import smiles_to_bigraph, RandomSplitter + from dgllife.utils import SMILESToBigraph, RandomSplitter from configure import configs from run import main @@ -67,13 +67,14 @@ # Setup for experiments mkdir_p(args['result_path']) - node_featurizer = atom_featurizer edge_featurizer = CanonicalBondFeaturizer(bond_data_field='he', self_loop=True) df = pd.read_csv(args['csv_path']) + + smiles_to_g = SMILESToBigraph(add_self_loop=True, node_featurizer=atom_featurizer, + edge_featurizer=edge_featurizer) + dataset = MoleculeCSVDataset( - df, partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=node_featurizer, - edge_featurizer=edge_featurizer, + df, smiles_to_g, smiles_column=args['smiles_column'], cache_file_path=args['result_path'] + '/graph.bin', task_names=args['tasks'] @@ -84,4 +85,4 @@ dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1, random_state=0) - main(args, node_featurizer, edge_featurizer, train_set, val_set, test_set) + main(args, atom_featurizer, edge_featurizer, train_set, val_set, test_set) diff --git a/examples/property_prediction/csv_data_configuration/classification_inference.py b/examples/property_prediction/csv_data_configuration/classification_inference.py index 584fcc86..9979d43e 100644 --- a/examples/property_prediction/csv_data_configuration/classification_inference.py +++ b/examples/property_prediction/csv_data_configuration/classification_inference.py @@ -9,17 +9,17 @@ import torch from dgllife.data import UnlabeledSMILES -from dgllife.utils import mol_to_bigraph -from functools import partial +from dgllife.utils import MolToBigraph from torch.utils.data import DataLoader from tqdm import tqdm from utils import mkdir_p, collate_molgraphs_unlabeled, load_model, predict, init_featurizer def main(args): - dataset = UnlabeledSMILES(args['smiles'], node_featurizer=args['node_featurizer'], - edge_featurizer=args['edge_featurizer'], - mol_to_graph=partial(mol_to_bigraph, add_self_loop=True)) + mol_to_g = MolToBigraph(add_self_loop=True, + node_featurizer=args['node_featurizer'], + edge_featurizer=args['edge_featurizer']) + dataset = UnlabeledSMILES(args['smiles'], mol_to_graph=mol_to_g) dataloader = DataLoader(dataset, batch_size=args['batch_size'], collate_fn=collate_molgraphs_unlabeled, num_workers=args['num_workers']) model = load_model(args).to(args['device']) diff --git a/examples/property_prediction/csv_data_configuration/regression_inference.py b/examples/property_prediction/csv_data_configuration/regression_inference.py index 57281fee..24f34433 100644 --- a/examples/property_prediction/csv_data_configuration/regression_inference.py +++ b/examples/property_prediction/csv_data_configuration/regression_inference.py @@ -9,17 +9,17 @@ import torch from dgllife.data import UnlabeledSMILES -from dgllife.utils import mol_to_bigraph -from functools import partial +from dgllife.utils import MolToBigraph from torch.utils.data import DataLoader from tqdm import tqdm from utils import mkdir_p, collate_molgraphs_unlabeled, load_model, predict, init_featurizer def main(args): - dataset = UnlabeledSMILES(args['smiles'], node_featurizer=args['node_featurizer'], - edge_featurizer=args['edge_featurizer'], - mol_to_graph=partial(mol_to_bigraph, add_self_loop=True)) + mol_to_g = MolToBigraph(add_self_loop=True, + node_featurizer=args['node_featurizer'], + edge_featurizer=args['edge_featurizer']) + dataset = UnlabeledSMILES(args['smiles'], mol_to_graph=mol_to_g) dataloader = DataLoader(dataset, batch_size=args['batch_size'], collate_fn=collate_molgraphs_unlabeled, num_workers=args['num_workers']) model = load_model(args).to(args['device']) diff --git a/examples/property_prediction/csv_data_configuration/utils.py b/examples/property_prediction/csv_data_configuration/utils.py index f550cddb..0421a40e 100644 --- a/examples/property_prediction/csv_data_configuration/utils.py +++ b/examples/property_prediction/csv_data_configuration/utils.py @@ -11,8 +11,7 @@ import torch.nn.functional as F from dgllife.data import MoleculeCSVDataset -from dgllife.utils import smiles_to_bigraph, ScaffoldSplitter, RandomSplitter, mol_to_bigraph -from functools import partial +from dgllife.utils import SMILESToBigraph, ScaffoldSplitter, RandomSplitter def init_featurizer(args): """Initialize node/edge featurizer @@ -60,10 +59,10 @@ def init_featurizer(args): return args def load_dataset(args, df): + smiles_to_g = SMILESToBigraph(add_self_loop=True, node_featurizer=args['node_featurizer'], + edge_featurizer=args['edge_featurizer']) dataset = MoleculeCSVDataset(df=df, - smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=args['node_featurizer'], - edge_featurizer=args['edge_featurizer'], + smiles_to_graph=smiles_to_g, smiles_column=args['smiles_column'], cache_file_path=args['result_path'] + '/graph.bin', task_names=args['task_names'], diff --git a/examples/property_prediction/moleculenet/classification.py b/examples/property_prediction/moleculenet/classification.py index f894116d..3d6dd61a 100644 --- a/examples/property_prediction/moleculenet/classification.py +++ b/examples/property_prediction/moleculenet/classification.py @@ -8,8 +8,7 @@ import torch.nn as nn from dgllife.model import load_pretrained -from dgllife.utils import smiles_to_bigraph, EarlyStopping, Meter -from functools import partial +from dgllife.utils import EarlyStopping, Meter, SMILESToBigraph from torch.optim import Adam from torch.utils.data import DataLoader @@ -160,59 +159,43 @@ def main(args, exp_config, train_set, val_set, test_set): args = init_featurizer(args) mkdir_p(args['result_path']) + smiles_to_g = SMILESToBigraph(add_self_loop=True, node_featurizer=args['node_featurizer'], + edge_featurizer=args['edge_featurizer']) if args['dataset'] == 'MUV': from dgllife.data import MUV - dataset = MUV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=args['node_featurizer'], - edge_featurizer=args['edge_featurizer'], + dataset = MUV(smiles_to_graph=smiles_to_g, n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) elif args['dataset'] == 'BACE': from dgllife.data import BACE - dataset = BACE(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=args['node_featurizer'], - edge_featurizer=args['edge_featurizer'], + dataset = BACE(smiles_to_graph=smiles_to_g, n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) elif args['dataset'] == 'BBBP': from dgllife.data import BBBP - dataset = BBBP(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=args['node_featurizer'], - edge_featurizer=args['edge_featurizer'], + dataset = BBBP(smiles_to_graph=smiles_to_g, n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) elif args['dataset'] == 'ClinTox': from dgllife.data import ClinTox - dataset = ClinTox(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=args['node_featurizer'], - edge_featurizer=args['edge_featurizer'], + dataset = ClinTox(smiles_to_graph=smiles_to_g, n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) elif args['dataset'] == 'SIDER': from dgllife.data import SIDER - dataset = SIDER(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=args['node_featurizer'], - edge_featurizer=args['edge_featurizer'], + dataset = SIDER(smiles_to_graph=smiles_to_g, n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) elif args['dataset'] == 'ToxCast': from dgllife.data import ToxCast - dataset = ToxCast(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=args['node_featurizer'], - edge_featurizer=args['edge_featurizer'], + dataset = ToxCast(smiles_to_graph=smiles_to_g, n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) elif args['dataset'] == 'HIV': from dgllife.data import HIV - dataset = HIV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=args['node_featurizer'], - edge_featurizer=args['edge_featurizer'], + dataset = HIV(smiles_to_graph=smiles_to_g, n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) elif args['dataset'] == 'PCBA': from dgllife.data import PCBA - dataset = PCBA(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=args['node_featurizer'], - edge_featurizer=args['edge_featurizer'], + dataset = PCBA(smiles_to_graph=smiles_to_g, n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) elif args['dataset'] == 'Tox21': from dgllife.data import Tox21 - dataset = Tox21(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=args['node_featurizer'], - edge_featurizer=args['edge_featurizer'], + dataset = Tox21(smiles_to_graph=smiles_to_g, n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) else: raise ValueError('Unexpected dataset: {}'.format(args['dataset'])) diff --git a/examples/property_prediction/moleculenet/regression.py b/examples/property_prediction/moleculenet/regression.py index f9232987..4030b379 100644 --- a/examples/property_prediction/moleculenet/regression.py +++ b/examples/property_prediction/moleculenet/regression.py @@ -8,8 +8,7 @@ import torch.nn as nn from dgllife.model import load_pretrained -from dgllife.utils import smiles_to_bigraph, EarlyStopping, Meter -from functools import partial +from dgllife.utils import EarlyStopping, Meter, SMILESToBigraph from torch.optim import Adam from torch.utils.data import DataLoader @@ -157,23 +156,19 @@ def main(args, exp_config, train_set, val_set, test_set): args = init_featurizer(args) mkdir_p(args['result_path']) + smiles_to_g = SMILESToBigraph(add_self_loop=True, node_featurizer=args['node_featurizer'], + edge_featurizer=args['edge_featurizer']) if args['dataset'] == 'FreeSolv': from dgllife.data import FreeSolv - dataset = FreeSolv(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=args['node_featurizer'], - edge_featurizer=args['edge_featurizer'], + dataset = FreeSolv(smiles_to_graph=smiles_to_g, n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) elif args['dataset'] == 'Lipophilicity': from dgllife.data import Lipophilicity - dataset = Lipophilicity(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=args['node_featurizer'], - edge_featurizer=args['edge_featurizer'], + dataset = Lipophilicity(smiles_to_graph=smiles_to_g, n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) elif args['dataset'] == 'ESOL': from dgllife.data import ESOL - dataset = ESOL(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=args['node_featurizer'], - edge_featurizer=args['edge_featurizer'], + dataset = ESOL(smiles_to_graph=smiles_to_g, n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) else: raise ValueError('Unexpected dataset: {}'.format(args['dataset'])) diff --git a/examples/property_prediction/pretrain_gnns/chem/README.md b/examples/property_prediction/pretrain_gnns/chem/README.md index 18ef1aee..8bfd46e3 100644 --- a/examples/property_prediction/pretrain_gnns/chem/README.md +++ b/examples/property_prediction/pretrain_gnns/chem/README.md @@ -16,7 +16,7 @@ This is a DGL implementation of the following paper based on PyTorch. This paper purposed an attribute masking pre-training method. It randomly masks input node/edge attributes by replacing them with special masked indicators, then the GNN will predict those attributes based on neighboring structure. ``` bash -python pretrain_masking.py --output_model_file OUTPUT_MODEL_FILE +python pretrain_masking.py --output_model_file OUTPUT_MODEL_FILE ``` The self-supervised pre-training model will be found in `OUTPUT_MODEL_FILE` after training (default filename: pretrain_masking.pth). diff --git a/examples/property_prediction/pretrain_gnns/chem/classification.py b/examples/property_prediction/pretrain_gnns/chem/classification.py index 73b2ed62..bcb0e9a9 100644 --- a/examples/property_prediction/pretrain_gnns/chem/classification.py +++ b/examples/property_prediction/pretrain_gnns/chem/classification.py @@ -4,14 +4,13 @@ # https://github.com/awslabs/dgl-lifesci/blob/master/examples/property_prediction/moleculenet/classification.py import argparse -from functools import partial import numpy as np import torch import torch.nn as nn from dgllife.utils import PretrainAtomFeaturizer from dgllife.utils import PretrainBondFeaturizer -from dgllife.utils import smiles_to_bigraph, Meter, EarlyStopping +from dgllife.utils import Meter, EarlyStopping, SMILESToBigraph from dgllife.model.model_zoo.gin_predictor import GINPredictor from torch.utils.data import DataLoader @@ -155,69 +154,53 @@ def main(args, dataset, device): atom_featurizer = PretrainAtomFeaturizer() bond_featurizer = PretrainBondFeaturizer() + smiles_to_g = SMILESToBigraph(add_self_loop=True, node_featurizer=atom_featurizer, + edge_featurizer=bond_featurizer) if args.dataset == 'MUV': from dgllife.data import MUV - dataset = MUV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=atom_featurizer, - edge_featurizer=bond_featurizer, + dataset = MUV(smiles_to_graph=smiles_to_g, n_jobs=1 if args.num_workers == 0 else args.num_workers) elif args.dataset == 'BACE': from dgllife.data import BACE - dataset = BACE(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=atom_featurizer, - edge_featurizer=bond_featurizer, + dataset = BACE(smiles_to_graph=smiles_to_g, n_jobs=1 if args.num_workers == 0 else args.num_workers) elif args.dataset == 'BBBP': from dgllife.data import BBBP - dataset = BBBP(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=atom_featurizer, - edge_featurizer=bond_featurizer, + dataset = BBBP(smiles_to_graph=smiles_to_g, n_jobs=1 if args.num_workers == 0 else args.num_workers) elif args.dataset == 'ClinTox': from dgllife.data import ClinTox - dataset = ClinTox(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=atom_featurizer, - edge_featurizer=bond_featurizer, + dataset = ClinTox(smiles_to_graph=smiles_to_g, n_jobs=1 if args.num_workers == 0 else args.num_workers) elif args.dataset == 'SIDER': from dgllife.data import SIDER - dataset = SIDER(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=atom_featurizer, - edge_featurizer=bond_featurizer, + dataset = SIDER(smiles_to_graph=smiles_to_g, n_jobs=1 if args.num_workers == 0 else args.num_workers) elif args.dataset == 'ToxCast': from dgllife.data import ToxCast - dataset = ToxCast(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=atom_featurizer, - edge_featurizer=bond_featurizer, + dataset = ToxCast(smiles_to_graph=smiles_to_g, n_jobs=1 if args.num_workers == 0 else args.num_workers) elif args.dataset == 'HIV': from dgllife.data import HIV - dataset = HIV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=atom_featurizer, - edge_featurizer=bond_featurizer, + dataset = HIV(smiles_to_graph=smiles_to_g, n_jobs=1 if args.num_workers == 0 else args.num_workers) elif args.dataset == 'PCBA': from dgllife.data import PCBA - dataset = PCBA(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=atom_featurizer, - edge_featurizer=bond_featurizer, + dataset = PCBA(smiles_to_graph=smiles_to_g, n_jobs=1 if args.num_workers == 0 else args.num_workers) elif args.dataset == 'Tox21': from dgllife.data import Tox21 - dataset = Tox21(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=atom_featurizer, - edge_featurizer=bond_featurizer, + dataset = Tox21(smiles_to_graph=smiles_to_g, n_jobs=1 if args.num_workers == 0 else args.num_workers) else: raise ValueError('Unexpected dataset: {}'.format(args.dataset)) diff --git a/examples/property_prediction/pretrain_gnns/chem/pretrain_masking.py b/examples/property_prediction/pretrain_gnns/chem/pretrain_masking.py index 501973cf..40371506 100644 --- a/examples/property_prediction/pretrain_gnns/chem/pretrain_masking.py +++ b/examples/property_prediction/pretrain_gnns/chem/pretrain_masking.py @@ -12,9 +12,7 @@ import random import dgl -from dgllife.utils import PretrainAtomFeaturizer -from dgllife.utils import PretrainBondFeaturizer -from dgllife.utils import smiles_to_bigraph +from dgllife.utils import PretrainAtomFeaturizer, PretrainBondFeaturizer, SMILESToBigraph from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive from dgllife.model.gnn.gin import GIN @@ -184,10 +182,10 @@ def main(): atom_featurizer = PretrainAtomFeaturizer() bond_featurizer = PretrainBondFeaturizer() + smiles_to_g = SMILESToBigraph(add_self_loop=True, node_featurizer=atom_featurizer, + edge_featurizer=bond_featurizer) dataset = PretrainDataset(data=data, - smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), - node_featurizer=atom_featurizer, - edge_featurizer=bond_featurizer, + smiles_to_graph=smiles_to_g, smiles_column='smiles', task='masking') diff --git a/examples/property_prediction/pretrain_gnns/chem/pretrain_supervised.py b/examples/property_prediction/pretrain_gnns/chem/pretrain_supervised.py index e8a8102e..258af5d3 100644 --- a/examples/property_prediction/pretrain_gnns/chem/pretrain_supervised.py +++ b/examples/property_prediction/pretrain_gnns/chem/pretrain_supervised.py @@ -2,16 +2,13 @@ import argparse import pickle import tqdm -from functools import partial import torch import torch.nn as nn from torch.utils.data import DataLoader import dgl -from dgllife.utils import PretrainAtomFeaturizer -from dgllife.utils import PretrainBondFeaturizer -from dgllife.utils import smiles_to_bigraph +from dgllife.utils import PretrainAtomFeaturizer, PretrainBondFeaturizer, SMILESToBigraph from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive from dgllife.model.model_zoo.gin_predictor import GINPredictor @@ -148,11 +145,10 @@ def main(): atom_featurizer = PretrainAtomFeaturizer() bond_featurizer = PretrainBondFeaturizer() + smiles_to_g = SMILESToBigraph(add_self_loop=True, node_featurizer=atom_featurizer, + edge_featurizer=bond_featurizer) dataset = PretrainDataset(data=data, - smiles_to_graph=partial( - smiles_to_bigraph, add_self_loop=True), - node_featurizer=atom_featurizer, - edge_featurizer=bond_featurizer, + smiles_to_graph=smiles_to_g, task='supervised') train_dataloader = DataLoader(dataset=dataset, diff --git a/examples/property_prediction/pretrain_gnns/chem/utils.py b/examples/property_prediction/pretrain_gnns/chem/utils.py index 7b028e15..4ba67d23 100644 --- a/examples/property_prediction/pretrain_gnns/chem/utils.py +++ b/examples/property_prediction/pretrain_gnns/chem/utils.py @@ -27,22 +27,18 @@ class PretrainDataset(object): used for pretrain_masking(task=masking) and pretrain_supervised(task=supervised) task. """ - def __init__(self, data, smiles_to_graph, node_featurizer, edge_featurizer, smiles_column=None, task=None): + def __init__(self, data, smiles_to_graph, smiles_column=None, task=None): self.data = data self.smiles_column = smiles_column if task == 'masking': self.smiles = self.data[smiles_column].tolist() self.smiles_to_graph = smiles_to_graph - self.node_featurizer = node_featurizer - self.edge_featurizer = edge_featurizer self.task = task self._pre_process() def __getitem__(self, item): s = self.smiles[item] - graph = self.smiles_to_graph(s, - node_featurizer=self.node_featurizer, - edge_featurizer=self.edge_featurizer) + graph = self.smiles_to_graph(s) if self.task == 'masking': return graph elif self.task == 'supervised': diff --git a/python/dgllife/data/bace.py b/python/dgllife/data/bace.py index 0faa5416..baa9f4cf 100644 --- a/python/dgllife/data/bace.py +++ b/python/dgllife/data/bace.py @@ -12,7 +12,6 @@ from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive from .csv_dataset import MoleculeCSVDataset -from ..utils.mol_to_graph import smiles_to_bigraph __all__ = ['BACE'] @@ -31,8 +30,8 @@ class BACE(MoleculeCSVDataset): Parameters ---------- smiles_to_graph: callable, str -> DGLGraph - A function turning a SMILES string into a DGLGraph. - Default to :func:`dgllife.utils.smiles_to_bigraph`. + A function turning a SMILES string into a DGLGraph. If None, it uses + :func:`dgllife.utils.SMILESToBigraph` by default. node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict Featurization for nodes like atoms in a molecule, which can be used to update ndata for a DGLGraph. Default to None. @@ -56,9 +55,10 @@ class BACE(MoleculeCSVDataset): >>> import torch >>> from dgllife.data import BACE - >>> from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer + >>> from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer - >>> dataset = BACE(smiles_to_bigraph, CanonicalAtomFeaturizer()) + >>> smiles_to_g = SMILESToBigraph(node_featurizer=CanonicalAtomFeaturizer()) + >>> dataset = BACE(smiles_to_g) >>> # Get size of the dataset >>> len(dataset) 1513 @@ -95,7 +95,7 @@ class BACE(MoleculeCSVDataset): tensor([0.2594]) """ def __init__(self, - smiles_to_graph=smiles_to_bigraph, + smiles_to_graph=None, node_featurizer=None, edge_featurizer=None, load=False, diff --git a/python/dgllife/data/bbbp.py b/python/dgllife/data/bbbp.py index 617350a4..872f57b5 100644 --- a/python/dgllife/data/bbbp.py +++ b/python/dgllife/data/bbbp.py @@ -10,7 +10,6 @@ from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive from .csv_dataset import MoleculeCSVDataset -from ..utils.mol_to_graph import smiles_to_bigraph __all__ = ['BBBP'] @@ -33,8 +32,8 @@ class BBBP(MoleculeCSVDataset): Parameters ---------- smiles_to_graph: callable, str -> DGLGraph - A function turning a SMILES string into a DGLGraph. - Default to :func:`dgllife.utils.smiles_to_bigraph`. + A function turning a SMILES string into a DGLGraph. If None, it uses + :func:`dgllife.utils.SMILESToBigraph` by default. node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict Featurization for nodes like atoms in a molecule, which can be used to update ndata for a DGLGraph. Default to None. @@ -58,9 +57,10 @@ class BBBP(MoleculeCSVDataset): >>> import torch >>> from dgllife.data import BBBP - >>> from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer + >>> from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer - >>> dataset = BBBP(smiles_to_bigraph, CanonicalAtomFeaturizer()) + >>> smiles_to_g = SMILESToBigraph(node_featurizer=CanonicalAtomFeaturizer()) + >>> dataset = BBBP(smiles_to_g) >>> # Get size of the dataset >>> len(dataset) 2039 @@ -97,7 +97,7 @@ class BBBP(MoleculeCSVDataset): tensor([0.7123]) """ def __init__(self, - smiles_to_graph=smiles_to_bigraph, + smiles_to_graph=None, node_featurizer=None, edge_featurizer=None, load=False, diff --git a/python/dgllife/data/clintox.py b/python/dgllife/data/clintox.py index f0343fdf..f09c4a66 100644 --- a/python/dgllife/data/clintox.py +++ b/python/dgllife/data/clintox.py @@ -11,7 +11,6 @@ from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive from .csv_dataset import MoleculeCSVDataset -from ..utils.mol_to_graph import smiles_to_bigraph __all__ = ['ClinTox'] @@ -34,8 +33,8 @@ class ClinTox(MoleculeCSVDataset): Parameters ---------- smiles_to_graph: callable, str -> DGLGraph - A function turning a SMILES string into a DGLGraph. - Default to :func:`dgllife.utils.smiles_to_bigraph`. + A function turning a SMILES string into a DGLGraph. If None, it uses + :func:`dgllife.utils.SMILESToBigraph` by default. node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict Featurization for nodes like atoms in a molecule, which can be used to update ndata for a DGLGraph. Default to None. @@ -59,9 +58,10 @@ class ClinTox(MoleculeCSVDataset): >>> import torch >>> from dgllife.data import ClinTox - >>> from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer + >>> from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer - >>> dataset = ClinTox(smiles_to_bigraph, CanonicalAtomFeaturizer()) + >>> smiles_to_g = SMILESToBigraph(node_featurizer=CanonicalAtomFeaturizer()) + >>> dataset = ClinTox(smiles_to_g) >>> # Get size of the dataset >>> len(dataset) 1478 @@ -82,7 +82,7 @@ class ClinTox(MoleculeCSVDataset): tensor([ 0.0684, 10.9048]) """ def __init__(self, - smiles_to_graph=smiles_to_bigraph, + smiles_to_graph=None, node_featurizer=None, edge_featurizer=None, load=False, diff --git a/python/dgllife/data/csv_dataset.py b/python/dgllife/data/csv_dataset.py index 49776838..8cb8436d 100644 --- a/python/dgllife/data/csv_dataset.py +++ b/python/dgllife/data/csv_dataset.py @@ -12,8 +12,10 @@ import torch from dgl.data.utils import save_graphs, load_graphs +from functools import partial from ..utils.io import pmap +from ..utils.mol_to_graph import ToGraph, SMILESToBigraph __all__ = ['MoleculeCSVDataset'] @@ -33,7 +35,8 @@ class MoleculeCSVDataset(object): Dataframe including smiles and labels. Can be loaded by pandas.read_csv(file_path). One column includes smiles and some other columns include labels. smiles_to_graph: callable, str -> DGLGraph - A function turning a SMILES string into a DGLGraph. + A function turning a SMILES string into a DGLGraph. If None, it uses + :func:`dgllife.utils.SMILESToBigraph` by default. node_featurizer : None or callable, rdkit.Chem.rdchem.Mol -> dict Featurization for nodes like atoms in a molecule, which can be used to update ndata for a DGLGraph. @@ -63,9 +66,9 @@ class MoleculeCSVDataset(object): Path to a CSV file of molecules that RDKit failed to parse. If not specified, the molecules will not be recorded. """ - def __init__(self, df, smiles_to_graph, node_featurizer, edge_featurizer, smiles_column, - cache_file_path, task_names=None, load=False, log_every=1000, init_mask=True, - n_jobs=1, error_log=None): + def __init__(self, df, smiles_to_graph=None, node_featurizer=None, edge_featurizer=None, + smiles_column=None, cache_file_path=None, task_names=None, load=False, + log_every=1000, init_mask=True, n_jobs=1, error_log=None): self.df = df self.smiles = self.df[smiles_column].tolist() if task_names is None: @@ -74,14 +77,25 @@ def __init__(self, df, smiles_to_graph, node_featurizer, edge_featurizer, smiles self.task_names = task_names self.n_tasks = len(self.task_names) self.cache_file_path = cache_file_path - self._pre_process(smiles_to_graph, node_featurizer, edge_featurizer, - load, log_every, init_mask, n_jobs, error_log) + + if isinstance(smiles_to_graph, ToGraph): + assert node_featurizer is None, \ + 'Initialize smiles_to_graph object with node_featurizer=node_featurizer' + assert edge_featurizer is None, \ + 'Initialize smiles_to_graph object with edge_featurizer=edge_featurizer' + elif smiles_to_graph is None: + smiles_to_graph = SMILESToBigraph(node_featurizer=node_featurizer, + edge_featurizer=edge_featurizer) + else: + smiles_to_graph = partial(smiles_to_graph, node_featurizer=node_featurizer, + edge_featurizer=edge_featurizer) + + self._pre_process(smiles_to_graph, load, log_every, init_mask, n_jobs, error_log) # Only useful for binary classification tasks self._task_pos_weights = None - def _pre_process(self, smiles_to_graph, node_featurizer, edge_featurizer, - load, log_every, init_mask, n_jobs, error_log): + def _pre_process(self, smiles_to_graph, load, log_every, init_mask, n_jobs, error_log): """Pre-process the dataset * Convert molecules from smiles format into DGLGraphs @@ -93,12 +107,6 @@ def _pre_process(self, smiles_to_graph, node_featurizer, edge_featurizer, ---------- smiles_to_graph : callable, SMILES -> DGLGraph Function for converting a SMILES (str) into a DGLGraph. - node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict - Featurization for nodes like atoms in a molecule, which can be used to update - ndata for a DGLGraph. - edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict - Featurization for edges like bonds in a molecule, which can be used to update - edata for a DGLGraph. load : bool Whether to load the previously pre-processed dataset or pre-process from scratch. ``load`` should be False when we want to try different graph construction and @@ -127,16 +135,13 @@ def _pre_process(self, smiles_to_graph, node_featurizer, edge_featurizer, if n_jobs > 1: self.graphs = pmap(smiles_to_graph, self.smiles, - node_featurizer=node_featurizer, - edge_featurizer=edge_featurizer, n_jobs=n_jobs) else: self.graphs = [] for i, s in enumerate(self.smiles): if (i + 1) % log_every == 0: print('Processing molecule {:d}/{:d}'.format(i+1, len(self))) - self.graphs.append(smiles_to_graph(s, node_featurizer=node_featurizer, - edge_featurizer=edge_featurizer)) + self.graphs.append(smiles_to_graph(s)) # Keep only valid molecules self.valid_ids = [] diff --git a/python/dgllife/data/esol.py b/python/dgllife/data/esol.py index 01b4bb6f..8304d680 100644 --- a/python/dgllife/data/esol.py +++ b/python/dgllife/data/esol.py @@ -10,7 +10,6 @@ from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive from .csv_dataset import MoleculeCSVDataset -from ..utils.mol_to_graph import smiles_to_bigraph __all__ = ['ESOL'] @@ -32,8 +31,8 @@ class ESOL(MoleculeCSVDataset): Parameters ---------- smiles_to_graph: callable, str -> DGLGraph - A function turning a SMILES string into a DGLGraph. - Default to :func:`dgllife.utils.smiles_to_bigraph`. + A function turning a SMILES string into a DGLGraph. If None, it uses + :func:`dgllife.utils.SMILESToBigraph` by default. node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict Featurization for nodes like atoms in a molecule, which can be used to update ndata for a DGLGraph. Default to None. @@ -56,9 +55,10 @@ class ESOL(MoleculeCSVDataset): -------- >>> from dgllife.data import ESOL - >>> from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer + >>> from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer - >>> dataset = ESOL(smiles_to_bigraph, CanonicalAtomFeaturizer()) + >>> smiles_to_g = SMILESToBigraph(node_featurizer=CanonicalAtomFeaturizer()) + >>> dataset = ESOL(smiles_to_g) >>> # Get size of the dataset >>> len(dataset) 1128 @@ -104,7 +104,7 @@ class ESOL(MoleculeCSVDataset): 202.32) """ def __init__(self, - smiles_to_graph=smiles_to_bigraph, + smiles_to_graph=None, node_featurizer=None, edge_featurizer=None, load=False, diff --git a/python/dgllife/data/freesolv.py b/python/dgllife/data/freesolv.py index 6d8bb4ab..5813fd0b 100644 --- a/python/dgllife/data/freesolv.py +++ b/python/dgllife/data/freesolv.py @@ -11,7 +11,6 @@ from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive from .csv_dataset import MoleculeCSVDataset -from ..utils.mol_to_graph import smiles_to_bigraph __all__ = ['FreeSolv'] @@ -34,8 +33,8 @@ class FreeSolv(MoleculeCSVDataset): Parameters ---------- smiles_to_graph: callable, str -> DGLGraph - A function turning a SMILES string into a DGLGraph. - Default to :func:`dgllife.utils.smiles_to_bigraph`. + A function turning a SMILES string into a DGLGraph. If None, it uses + :func:`dgllife.utils.SMILESToBigraph` by default. node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict Featurization for nodes like atoms in a molecule, which can be used to update ndata for a DGLGraph. Default to None. @@ -58,9 +57,10 @@ class FreeSolv(MoleculeCSVDataset): -------- >>> from dgllife.data import FreeSolv - >>> from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer + >>> from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer - >>> dataset = FreeSolv(smiles_to_bigraph, CanonicalAtomFeaturizer()) + >>> smiles_to_g = SMILESToBigraph(node_featurizer=CanonicalAtomFeaturizer()) + >>> dataset = FreeSolv(smiles_to_g) >>> # Get size of the dataset >>> len(dataset) 642 @@ -92,7 +92,7 @@ class FreeSolv(MoleculeCSVDataset): -9.625) """ def __init__(self, - smiles_to_graph=smiles_to_bigraph, + smiles_to_graph=None, node_featurizer=None, edge_featurizer=None, load=False, diff --git a/python/dgllife/data/hiv.py b/python/dgllife/data/hiv.py index 05e5dcef..e4ca8b58 100644 --- a/python/dgllife/data/hiv.py +++ b/python/dgllife/data/hiv.py @@ -10,7 +10,6 @@ from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive from .csv_dataset import MoleculeCSVDataset -from ..utils.mol_to_graph import smiles_to_bigraph __all__ = ['HIV'] @@ -31,8 +30,8 @@ class HIV(MoleculeCSVDataset): Parameters ---------- smiles_to_graph: callable, str -> DGLGraph - A function turning a SMILES string into a DGLGraph. - Default to :func:`dgllife.utils.smiles_to_bigraph`. + A function turning a SMILES string into a DGLGraph. If None, it uses + :func:`dgllife.utils.SMILESToBigraph` by default. node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict Featurization for nodes like atoms in a molecule, which can be used to update ndata for a DGLGraph. Default to None. @@ -56,9 +55,10 @@ class HIV(MoleculeCSVDataset): >>> import torch >>> from dgllife.data import HIV - >>> from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer + >>> from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer - >>> dataset = HIV(smiles_to_bigraph, CanonicalAtomFeaturizer()) + >>> smiles_to_g = SMILESToBigraph(node_featurizer=CanonicalAtomFeaturizer()) + >>> dataset = HIV(smiles_to_g) >>> # Get size of the dataset >>> len(dataset) 41127 @@ -95,7 +95,7 @@ class HIV(MoleculeCSVDataset): tensor([33.1880]) """ def __init__(self, - smiles_to_graph=smiles_to_bigraph, + smiles_to_graph=None, node_featurizer=None, edge_featurizer=None, load=False, diff --git a/python/dgllife/data/lipophilicity.py b/python/dgllife/data/lipophilicity.py index 9b4a4f59..105073f6 100644 --- a/python/dgllife/data/lipophilicity.py +++ b/python/dgllife/data/lipophilicity.py @@ -11,7 +11,6 @@ from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive from .csv_dataset import MoleculeCSVDataset -from ..utils.mol_to_graph import smiles_to_bigraph __all__ = ['Lipophilicity'] @@ -32,8 +31,8 @@ class Lipophilicity(MoleculeCSVDataset): Parameters ---------- smiles_to_graph: callable, str -> DGLGraph - A function turning a SMILES string into a DGLGraph. - Default to :func:`dgllife.utils.smiles_to_bigraph`. + A function turning a SMILES string into a DGLGraph. If None, it uses + :func:`dgllife.utils.SMILESToBigraph` by default. node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict Featurization for nodes like atoms in a molecule, which can be used to update ndata for a DGLGraph. Default to None. @@ -56,9 +55,10 @@ class Lipophilicity(MoleculeCSVDataset): -------- >>> from dgllife.data import Lipophilicity - >>> from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer + >>> from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer - >>> dataset = Lipophilicity(smiles_to_bigraph, CanonicalAtomFeaturizer()) + >>> smiles_to_g = SMILESToBigraph(node_featurizer=CanonicalAtomFeaturizer()) + >>> dataset = Lipophilicity(smiles_to_g) >>> # Get size of the dataset >>> len(dataset) 4200 @@ -86,7 +86,7 @@ class Lipophilicity(MoleculeCSVDataset): 'CHEMBL596271') """ def __init__(self, - smiles_to_graph=smiles_to_bigraph, + smiles_to_graph=None, node_featurizer=None, edge_featurizer=None, load=False, diff --git a/python/dgllife/data/muv.py b/python/dgllife/data/muv.py index 0791c075..25673e11 100644 --- a/python/dgllife/data/muv.py +++ b/python/dgllife/data/muv.py @@ -10,7 +10,6 @@ from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive from .csv_dataset import MoleculeCSVDataset -from ..utils.mol_to_graph import smiles_to_bigraph __all__ = ['MUV'] @@ -30,8 +29,8 @@ class MUV(MoleculeCSVDataset): Parameters ---------- smiles_to_graph: callable, str -> DGLGraph - A function turning a SMILES string into a DGLGraph. - Default to :func:`dgllife.utils.smiles_to_bigraph`. + A function turning a SMILES string into a DGLGraph. If None, it uses + :func:`dgllife.utils.SMILESToBigraph` by default. node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict Featurization for nodes like atoms in a molecule, which can be used to update ndata for a DGLGraph. Default to None. @@ -55,9 +54,10 @@ class MUV(MoleculeCSVDataset): >>> import torch >>> from dgllife.data import MUV - >>> from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer + >>> from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer - >>> dataset = MUV(smiles_to_bigraph, CanonicalAtomFeaturizer()) + >>> smiles_to_g = SMILESToBigraph(node_featurizer=CanonicalAtomFeaturizer()) + >>> dataset = MUV(smiles_to_g) >>> # Get size of the dataset >>> len(dataset) 93087 @@ -96,7 +96,7 @@ class MUV(MoleculeCSVDataset): 1262.8000, 702.1111, 571.3636, 528.0000, 485.2308]) """ def __init__(self, - smiles_to_graph=smiles_to_bigraph, + smiles_to_graph=None, node_featurizer=None, edge_featurizer=None, load=False, diff --git a/python/dgllife/data/pcba.py b/python/dgllife/data/pcba.py index 0afba449..24841955 100644 --- a/python/dgllife/data/pcba.py +++ b/python/dgllife/data/pcba.py @@ -10,7 +10,6 @@ from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive from .csv_dataset import MoleculeCSVDataset -from ..utils.mol_to_graph import smiles_to_bigraph __all__ = ['PCBA'] @@ -29,8 +28,8 @@ class PCBA(MoleculeCSVDataset): Parameters ---------- smiles_to_graph: callable, str -> DGLGraph - A function turning a SMILES string into a DGLGraph. - Default to :func:`dgllife.utils.smiles_to_bigraph`. + A function turning a SMILES string into a DGLGraph. If None, it uses + :func:`dgllife.utils.SMILESToBigraph` by default. node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict Featurization for nodes like atoms in a molecule, which can be used to update ndata for a DGLGraph. Default to None. @@ -54,9 +53,10 @@ class PCBA(MoleculeCSVDataset): >>> import torch >>> from dgllife.data import PCBA - >>> from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer + >>> from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer - >>> dataset = PCBA(smiles_to_bigraph, CanonicalAtomFeaturizer()) + >>> smiles_to_g = SMILESToBigraph(node_featurizer=CanonicalAtomFeaturizer()) + >>> dataset = PCBA(smiles_to_g) >>> # Get size of the dataset >>> len(dataset) 437929 @@ -93,7 +93,7 @@ class PCBA(MoleculeCSVDataset): tensor([7.3400, 489.0000, ..., 1.0000]) """ def __init__(self, - smiles_to_graph=smiles_to_bigraph, + smiles_to_graph=None, node_featurizer=None, edge_featurizer=None, load=False, diff --git a/python/dgllife/data/sider.py b/python/dgllife/data/sider.py index b0c26eda..5cfcb308 100644 --- a/python/dgllife/data/sider.py +++ b/python/dgllife/data/sider.py @@ -10,7 +10,6 @@ from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive from .csv_dataset import MoleculeCSVDataset -from ..utils.mol_to_graph import smiles_to_bigraph __all__ = ['SIDER'] @@ -28,8 +27,8 @@ class SIDER(MoleculeCSVDataset): Parameters ---------- smiles_to_graph: callable, str -> DGLGraph - A function turning a SMILES string into a DGLGraph. - Default to :func:`dgllife.utils.smiles_to_bigraph`. + A function turning a SMILES string into a DGLGraph. If None, it uses + :func:`dgllife.utils.SMILESToBigraph` by default. node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict Featurization for nodes like atoms in a molecule, which can be used to update ndata for a DGLGraph. Default to None. @@ -53,9 +52,10 @@ class SIDER(MoleculeCSVDataset): >>> import torch >>> from dgllife.data import SIDER - >>> from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer + >>> from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer - >>> dataset = SIDER(smiles_to_bigraph, CanonicalAtomFeaturizer()) + >>> smiles_to_g = SMILESToBigraph(node_featurizer=CanonicalAtomFeaturizer()) + >>> dataset = SIDER(smiles_to_g) >>> # Get size of the dataset >>> len(dataset) 1427 @@ -81,7 +81,7 @@ class SIDER(MoleculeCSVDataset): 0.5060, 0.1136, 0.5106]) """ def __init__(self, - smiles_to_graph=smiles_to_bigraph, + smiles_to_graph=None, node_featurizer=None, edge_featurizer=None, load=False, diff --git a/python/dgllife/data/smiles_inference.py b/python/dgllife/data/smiles_inference.py index 991e0bcf..a5025be4 100644 --- a/python/dgllife/data/smiles_inference.py +++ b/python/dgllife/data/smiles_inference.py @@ -5,9 +5,10 @@ # # Dataset for inference on smiles +from functools import partial from rdkit import Chem -from ..utils.mol_to_graph import mol_to_bigraph +from ..utils.mol_to_graph import ToGraph, MolToBigraph __all__ = ['UnlabeledSMILES'] @@ -33,7 +34,7 @@ class UnlabeledSMILES(object): log_every : bool Print a message every time ``log_every`` molecules are processed. Default to 1000. """ - def __init__(self, smiles_list, mol_to_graph=mol_to_bigraph, node_featurizer=None, + def __init__(self, smiles_list, mol_to_graph=None, node_featurizer=None, edge_featurizer=None, log_every=1000): super(UnlabeledSMILES, self).__init__() @@ -48,11 +49,24 @@ def __init__(self, smiles_list, mol_to_graph=mol_to_bigraph, node_featurizer=Non self.smiles = canonical_smiles self.graphs = [] + + if mol_to_graph is None: + mol_to_graph = MolToBigraph() + + # Check for backward compatibility + if isinstance(mol_to_graph, ToGraph): + assert node_featurizer is None, \ + 'Initialize mol_to_graph object with node_featurizer=node_featurizer' + assert edge_featurizer is None, \ + 'Initialize mol_to_graph object with edge_featurizer=edge_featurizer' + else: + mol_to_graph = partial(mol_to_graph, node_featurizer=node_featurizer, + edge_featurizer=edge_featurizer) + for i, mol in enumerate(mol_list): if (i + 1) % log_every == 0: print('Processing molecule {:d}/{:d}'.format(i + 1, len(self))) - self.graphs.append(mol_to_graph(mol, node_featurizer=node_featurizer, - edge_featurizer=edge_featurizer)) + self.graphs.append(mol_to_graph(mol)) def __getitem__(self, item): """Get datapoint with index diff --git a/python/dgllife/data/tox21.py b/python/dgllife/data/tox21.py index 5d8c5581..077488bf 100644 --- a/python/dgllife/data/tox21.py +++ b/python/dgllife/data/tox21.py @@ -10,7 +10,6 @@ from dgl.data.utils import get_download_dir, download, _get_dgl_url from .csv_dataset import MoleculeCSVDataset -from ..utils.mol_to_graph import smiles_to_bigraph __all__ = ['Tox21'] @@ -33,8 +32,8 @@ class Tox21(MoleculeCSVDataset): Parameters ---------- smiles_to_graph: callable, str -> DGLGraph - A function turning a SMILES string into a DGLGraph. - Default to :func:`dgllife.utils.smiles_to_bigraph`. + A function turning a SMILES string into a DGLGraph. If None, it uses + :func:`dgllife.utils.SMILESToBigraph` by default. node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict Featurization for nodes like atoms in a molecule, which can be used to update ndata for a DGLGraph. Default to None. @@ -57,9 +56,10 @@ class Tox21(MoleculeCSVDataset): -------- >>> from dgllife.data import Tox21 - >>> from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer + >>> from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer - >>> dataset = Tox21(smiles_to_bigraph, CanonicalAtomFeaturizer()) + >>> smiles_to_g = SMILESToBigraph(node_featurizer=CanonicalAtomFeaturizer()) + >>> dataset = Tox21(smiles_to_g) >>> # Get size of the dataset >>> len(dataset) 7831 @@ -96,7 +96,7 @@ class Tox21(MoleculeCSVDataset): tensor([26.9706, 35.3750, 5.9756, 21.6364, 6.4404, 21.4500, 26.0000, 5.0826, 21.4390, 14.7692, 6.1442, 12.4308]) """ - def __init__(self, smiles_to_graph=smiles_to_bigraph, + def __init__(self, smiles_to_graph=None, node_featurizer=None, edge_featurizer=None, load=False, diff --git a/python/dgllife/data/toxcast.py b/python/dgllife/data/toxcast.py index 0fd0e541..da8e923b 100644 --- a/python/dgllife/data/toxcast.py +++ b/python/dgllife/data/toxcast.py @@ -11,7 +11,6 @@ from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive from .csv_dataset import MoleculeCSVDataset -from ..utils.mol_to_graph import smiles_to_bigraph __all__ = ['ToxCast'] @@ -31,8 +30,8 @@ class ToxCast(MoleculeCSVDataset): Parameters ---------- smiles_to_graph: callable, str -> DGLGraph - A function turning a SMILES string into a DGLGraph. - Default to :func:`dgllife.utils.smiles_to_bigraph`. + A function turning a SMILES string into a DGLGraph. If None, it uses + :func:`dgllife.utils.SMILESToBigraph` by default. node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict Featurization for nodes like atoms in a molecule, which can be used to update ndata for a DGLGraph. Default to None. @@ -56,9 +55,10 @@ class ToxCast(MoleculeCSVDataset): >>> import torch >>> from dgllife.data import ToxCast - >>> from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer + >>> from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer - >>> dataset = ToxCast(smiles_to_bigraph, CanonicalAtomFeaturizer()) + >>> smiles_to_g = SMILESToBigraph(node_featurizer=CanonicalAtomFeaturizer()) + >>> dataset = ToxCast(smiles_to_g) >>> # Get size of the dataset >>> len(dataset) 8576 @@ -79,7 +79,7 @@ class ToxCast(MoleculeCSVDataset): tensor([4.0435e+00, ..., 1.7500e+01]) """ def __init__(self, - smiles_to_graph=smiles_to_bigraph, + smiles_to_graph=None, node_featurizer=None, edge_featurizer=None, load=False, diff --git a/python/dgllife/utils/mol_to_graph.py b/python/dgllife/utils/mol_to_graph.py index 9d235eae..9b3416e4 100644 --- a/python/dgllife/utils/mol_to_graph.py +++ b/python/dgllife/utils/mol_to_graph.py @@ -26,7 +26,10 @@ 'mol_to_complete_graph', 'k_nearest_neighbors', 'mol_to_nearest_neighbor_graph', - 'smiles_to_nearest_neighbor_graph'] + 'smiles_to_nearest_neighbor_graph', + 'ToGraph', + 'MolToBigraph', + 'SMILESToBigraph'] # pylint: disable=I1101 def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, @@ -1027,3 +1030,209 @@ def smiles_to_nearest_neighbor_graph(smiles, mol, coordinates, neighbor_cutoff, max_num_neighbors, p_distance, add_self_loop, node_featurizer, edge_featurizer, canonical_atom_order, keep_dists, dist_field, explicit_hydrogens, num_virtual_nodes) + +class ToGraph: + r"""An abstract class for writing graph constructors.""" + def __call__(self, data_obj): + raise NotImplementedError + + def __repr__(self): + return self.__class__.__name__ + '()' + +class MolToBigraph(ToGraph): + """Convert RDKit molecule objects into bi-directed DGLGraphs and featurize for them. + + Parameters + ---------- + add_self_loop : bool + Whether to add self loops in DGLGraphs. Default to False. + node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict + Featurization for nodes like atoms in a molecule, which can be used to update + ndata for a DGLGraph. Default to None. + edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict + Featurization for edges like bonds in a molecule, which can be used to update + edata for a DGLGraph. Default to None. + canonical_atom_order : bool + Whether to use a canonical order of atoms returned by RDKit. Setting it + to true might change the order of atoms in the graph constructed. Default + to True. + explicit_hydrogens : bool + Whether to explicitly represent hydrogens as nodes in the graph. If True, + it will call rdkit.Chem.AddHs(mol). Default to False. + num_virtual_nodes : int + The number of virtual nodes to add. The virtual nodes will be connected to + all real nodes with virtual edges. If the returned graph has any node/edge + feature, an additional column of binary values will be used for each feature + to indicate the identity of virtual node/edges. The features of the virtual + nodes/edges will be zero vectors except for the additional column. Default to 0. + + Examples + -------- + >>> import torch + >>> from rdkit import Chem + >>> from dgllife.utils import MolToBigraph + + >>> # A custom node featurizer + >>> def featurize_atoms(mol): + >>> feats = [] + >>> for atom in mol.GetAtoms(): + >>> feats.append(atom.GetAtomicNum()) + >>> return {'atomic': torch.tensor(feats).reshape(-1, 1).float()} + + >>> # A custom edge featurizer + >>> def featurize_bonds(mol): + >>> feats = [] + >>> bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, + >>> Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC] + >>> for bond in mol.GetBonds(): + >>> btype = bond_types.index(bond.GetBondType()) + >>> # One bond between atom u and v corresponds to two edges (u, v) and (v, u) + >>> feats.extend([btype, btype]) + >>> return {'type': torch.tensor(feats).reshape(-1, 1).float()} + + >>> mol_to_g = MolToBigraph(node_featurizer=featurize_atoms, edge_featurizer=featurize_bonds) + >>> mol = Chem.MolFromSmiles('CCO') + >>> g = mol_to_g(mol) + >>> print(g.ndata['atomic']) + tensor([[6.], + [8.], + [6.]]) + >>> print(g.edata['type']) + tensor([[0.], + [0.], + [0.], + [0.]]) + """ + def __init__(self, + add_self_loop=False, + node_featurizer=None, + edge_featurizer=None, + canonical_atom_order=True, + explicit_hydrogens=False, + num_virtual_nodes=0): + self.add_self_loop = add_self_loop + self.node_featurizer = node_featurizer + self.edge_featurizer = edge_featurizer + self.canonical_atom_order = canonical_atom_order + self.explicit_hydrogens = explicit_hydrogens + self.num_virtual_nodes = num_virtual_nodes + + def __call__(self, mol): + """Construct graph for the molecule and featurize it. + + Parameters + ---------- + mol : rdkit.Chem.rdchem.Mol + RDKit molecule holder + + Returns + ------- + DGLGraph or None + Bi-directed DGLGraph for the molecule if :attr:`mol` is valid and None otherwise. + """ + return mol_to_bigraph(mol, + self.add_self_loop, + self.node_featurizer, + self.edge_featurizer, + self.canonical_atom_order, + self.explicit_hydrogens, + self.num_virtual_nodes) + +class SMILESToBigraph(ToGraph): + """Convert SMILES strings into bi-directed DGLGraphs and featurize for them. + + Parameters + ---------- + add_self_loop : bool + Whether to add self loops in DGLGraphs. Default to False. + node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict + Featurization for nodes like atoms in a molecule, which can be used to update + ndata for a DGLGraph. Default to None. + edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict + Featurization for edges like bonds in a molecule, which can be used to update + edata for a DGLGraph. Default to None. + canonical_atom_order : bool + Whether to use a canonical order of atoms returned by RDKit. Setting it + to true might change the order of atoms in the graph constructed. Default + to True. + explicit_hydrogens : bool + Whether to explicitly represent hydrogens as nodes in the graph. If True, + it will call rdkit.Chem.AddHs(mol). Default to False. + num_virtual_nodes : int + The number of virtual nodes to add. The virtual nodes will be connected to + all real nodes with virtual edges. If the returned graph has any node/edge + feature, an additional column of binary values will be used for each feature + to indicate the identity of virtual node/edges. The features of the virtual + nodes/edges will be zero vectors except for the additional column. Default to 0. + + Examples + -------- + >>> import torch + >>> from rdkit import Chem + >>> from dgllife.utils import SMILESToBigraph + + >>> # A custom node featurizer + >>> def featurize_atoms(mol): + >>> feats = [] + >>> for atom in mol.GetAtoms(): + >>> feats.append(atom.GetAtomicNum()) + >>> return {'atomic': torch.tensor(feats).reshape(-1, 1).float()} + + >>> # A custom edge featurizer + >>> def featurize_bonds(mol): + >>> feats = [] + >>> bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, + >>> Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC] + >>> for bond in mol.GetBonds(): + >>> btype = bond_types.index(bond.GetBondType()) + >>> # One bond between atom u and v corresponds to two edges (u, v) and (v, u) + >>> feats.extend([btype, btype]) + >>> return {'type': torch.tensor(feats).reshape(-1, 1).float()} + + >>> smi_to_g = SMILESToBigraph(node_featurizer=featurize_atoms, + ... edge_featurizer=featurize_bonds) + >>> g = smi_to_g('CCO') + >>> print(g.ndata['atomic']) + tensor([[6.], + [8.], + [6.]]) + >>> print(g.edata['type']) + tensor([[0.], + [0.], + [0.], + [0.]]) + """ + def __init__(self, + add_self_loop=False, + node_featurizer=None, + edge_featurizer=None, + canonical_atom_order=True, + explicit_hydrogens=False, + num_virtual_nodes=0): + self.add_self_loop = add_self_loop + self.node_featurizer = node_featurizer + self.edge_featurizer = edge_featurizer + self.canonical_atom_order = canonical_atom_order + self.explicit_hydrogens = explicit_hydrogens + self.num_virtual_nodes = num_virtual_nodes + + def __call__(self, smiles): + """Construct graph for the molecule and featurize it. + + Parameters + ---------- + smiles : str + SMILES string. + + Returns + ------- + DGLGraph or None + Bi-directed DGLGraph for the molecule if :attr:`smiles` is valid and None otherwise. + """ + return smiles_to_bigraph(smiles, + self.add_self_loop, + self.node_featurizer, + self.edge_featurizer, + self.canonical_atom_order, + self.explicit_hydrogens, + self.num_virtual_nodes) diff --git a/tests/data/test_new_dataset.py b/tests/data/test_new_dataset.py index 3ff6b1e8..a858c5b1 100644 --- a/tests/data/test_new_dataset.py +++ b/tests/data/test_new_dataset.py @@ -11,6 +11,7 @@ from dgllife.utils.featurizers import * from dgllife.utils.mol_to_graph import * from joblib import cpu_count +from python.dgllife.utils.mol_to_graph import MolToBigraph def test_data_frame1(): data = [['CCO', 0, 1], ['CO', 2, 3]] @@ -97,8 +98,9 @@ def test_mol_csv(): def test_unlabeled_smiles(): smiles = ['CCO', 'CO'] - dataset = UnlabeledSMILES(smiles, node_featurizer=CanonicalAtomFeaturizer(), - edge_featurizer=CanonicalBondFeaturizer()) + mol_to_g = MolToBigraph(node_featurizer=CanonicalAtomFeaturizer(), + edge_featurizer=CanonicalBondFeaturizer()) + dataset = UnlabeledSMILES(smiles, mol_to_graph=mol_to_g) assert len(dataset) == 2 smiles, graph = dataset[0] assert 'h' in graph.ndata