Skip to content

Commit

Permalink
[Refactor] Re-implement some functional APIs as modules (#180)
Browse files Browse the repository at this point in the history
* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update
  • Loading branch information
mufeili authored Jun 11, 2022
1 parent 9784e62 commit a6f4596
Show file tree
Hide file tree
Showing 28 changed files with 392 additions and 208 deletions.
3 changes: 3 additions & 0 deletions docs/source/api/utils.mols.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ three common graph constructions:
dgllife.utils.k_nearest_neighbors
dgllife.utils.mol_to_nearest_neighbor_graph
dgllife.utils.smiles_to_nearest_neighbor_graph
dgllife.utils.ToGraph
dgllife.utils.MolToBigraph
dgllife.utils.SMILESToBigraph

Featurization for Molecules
---------------------------
Expand Down
13 changes: 7 additions & 6 deletions examples/property_prediction/MTL/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from argparse import ArgumentParser
from dgllife.data import MoleculeCSVDataset
from dgllife.utils import smiles_to_bigraph, RandomSplitter
from dgllife.utils import SMILESToBigraph, RandomSplitter

from configure import configs
from run import main
Expand Down Expand Up @@ -67,13 +67,14 @@
# Setup for experiments
mkdir_p(args['result_path'])

node_featurizer = atom_featurizer
edge_featurizer = CanonicalBondFeaturizer(bond_data_field='he', self_loop=True)
df = pd.read_csv(args['csv_path'])

smiles_to_g = SMILESToBigraph(add_self_loop=True, node_featurizer=atom_featurizer,
edge_featurizer=edge_featurizer)

dataset = MoleculeCSVDataset(
df, partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer,
df, smiles_to_g,
smiles_column=args['smiles_column'],
cache_file_path=args['result_path'] + '/graph.bin',
task_names=args['tasks']
Expand All @@ -84,4 +85,4 @@
dataset, frac_train=0.8, frac_val=0.1,
frac_test=0.1, random_state=0)

main(args, node_featurizer, edge_featurizer, train_set, val_set, test_set)
main(args, atom_featurizer, edge_featurizer, train_set, val_set, test_set)
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@
import torch

from dgllife.data import UnlabeledSMILES
from dgllife.utils import mol_to_bigraph
from functools import partial
from dgllife.utils import MolToBigraph
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import mkdir_p, collate_molgraphs_unlabeled, load_model, predict, init_featurizer

def main(args):
dataset = UnlabeledSMILES(args['smiles'], node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'],
mol_to_graph=partial(mol_to_bigraph, add_self_loop=True))
mol_to_g = MolToBigraph(add_self_loop=True,
node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'])
dataset = UnlabeledSMILES(args['smiles'], mol_to_graph=mol_to_g)
dataloader = DataLoader(dataset, batch_size=args['batch_size'],
collate_fn=collate_molgraphs_unlabeled, num_workers=args['num_workers'])
model = load_model(args).to(args['device'])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@
import torch

from dgllife.data import UnlabeledSMILES
from dgllife.utils import mol_to_bigraph
from functools import partial
from dgllife.utils import MolToBigraph
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import mkdir_p, collate_molgraphs_unlabeled, load_model, predict, init_featurizer

def main(args):
dataset = UnlabeledSMILES(args['smiles'], node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'],
mol_to_graph=partial(mol_to_bigraph, add_self_loop=True))
mol_to_g = MolToBigraph(add_self_loop=True,
node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'])
dataset = UnlabeledSMILES(args['smiles'], mol_to_graph=mol_to_g)
dataloader = DataLoader(dataset, batch_size=args['batch_size'],
collate_fn=collate_molgraphs_unlabeled, num_workers=args['num_workers'])
model = load_model(args).to(args['device'])
Expand Down
9 changes: 4 additions & 5 deletions examples/property_prediction/csv_data_configuration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
import torch.nn.functional as F

from dgllife.data import MoleculeCSVDataset
from dgllife.utils import smiles_to_bigraph, ScaffoldSplitter, RandomSplitter, mol_to_bigraph
from functools import partial
from dgllife.utils import SMILESToBigraph, ScaffoldSplitter, RandomSplitter

def init_featurizer(args):
"""Initialize node/edge featurizer
Expand Down Expand Up @@ -60,10 +59,10 @@ def init_featurizer(args):
return args

def load_dataset(args, df):
smiles_to_g = SMILESToBigraph(add_self_loop=True, node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'])
dataset = MoleculeCSVDataset(df=df,
smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'],
smiles_to_graph=smiles_to_g,
smiles_column=args['smiles_column'],
cache_file_path=args['result_path'] + '/graph.bin',
task_names=args['task_names'],
Expand Down
41 changes: 12 additions & 29 deletions examples/property_prediction/moleculenet/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import torch.nn as nn

from dgllife.model import load_pretrained
from dgllife.utils import smiles_to_bigraph, EarlyStopping, Meter
from functools import partial
from dgllife.utils import EarlyStopping, Meter, SMILESToBigraph
from torch.optim import Adam
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -160,59 +159,43 @@ def main(args, exp_config, train_set, val_set, test_set):

args = init_featurizer(args)
mkdir_p(args['result_path'])
smiles_to_g = SMILESToBigraph(add_self_loop=True, node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'])
if args['dataset'] == 'MUV':
from dgllife.data import MUV
dataset = MUV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'],
dataset = MUV(smiles_to_graph=smiles_to_g,
n_jobs=1 if args['num_workers'] == 0 else args['num_workers'])
elif args['dataset'] == 'BACE':
from dgllife.data import BACE
dataset = BACE(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'],
dataset = BACE(smiles_to_graph=smiles_to_g,
n_jobs=1 if args['num_workers'] == 0 else args['num_workers'])
elif args['dataset'] == 'BBBP':
from dgllife.data import BBBP
dataset = BBBP(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'],
dataset = BBBP(smiles_to_graph=smiles_to_g,
n_jobs=1 if args['num_workers'] == 0 else args['num_workers'])
elif args['dataset'] == 'ClinTox':
from dgllife.data import ClinTox
dataset = ClinTox(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'],
dataset = ClinTox(smiles_to_graph=smiles_to_g,
n_jobs=1 if args['num_workers'] == 0 else args['num_workers'])
elif args['dataset'] == 'SIDER':
from dgllife.data import SIDER
dataset = SIDER(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'],
dataset = SIDER(smiles_to_graph=smiles_to_g,
n_jobs=1 if args['num_workers'] == 0 else args['num_workers'])
elif args['dataset'] == 'ToxCast':
from dgllife.data import ToxCast
dataset = ToxCast(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'],
dataset = ToxCast(smiles_to_graph=smiles_to_g,
n_jobs=1 if args['num_workers'] == 0 else args['num_workers'])
elif args['dataset'] == 'HIV':
from dgllife.data import HIV
dataset = HIV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'],
dataset = HIV(smiles_to_graph=smiles_to_g,
n_jobs=1 if args['num_workers'] == 0 else args['num_workers'])
elif args['dataset'] == 'PCBA':
from dgllife.data import PCBA
dataset = PCBA(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'],
dataset = PCBA(smiles_to_graph=smiles_to_g,
n_jobs=1 if args['num_workers'] == 0 else args['num_workers'])
elif args['dataset'] == 'Tox21':
from dgllife.data import Tox21
dataset = Tox21(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'],
dataset = Tox21(smiles_to_graph=smiles_to_g,
n_jobs=1 if args['num_workers'] == 0 else args['num_workers'])
else:
raise ValueError('Unexpected dataset: {}'.format(args['dataset']))
Expand Down
17 changes: 6 additions & 11 deletions examples/property_prediction/moleculenet/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import torch.nn as nn

from dgllife.model import load_pretrained
from dgllife.utils import smiles_to_bigraph, EarlyStopping, Meter
from functools import partial
from dgllife.utils import EarlyStopping, Meter, SMILESToBigraph
from torch.optim import Adam
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -157,23 +156,19 @@ def main(args, exp_config, train_set, val_set, test_set):

args = init_featurizer(args)
mkdir_p(args['result_path'])
smiles_to_g = SMILESToBigraph(add_self_loop=True, node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'])
if args['dataset'] == 'FreeSolv':
from dgllife.data import FreeSolv
dataset = FreeSolv(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'],
dataset = FreeSolv(smiles_to_graph=smiles_to_g,
n_jobs=1 if args['num_workers'] == 0 else args['num_workers'])
elif args['dataset'] == 'Lipophilicity':
from dgllife.data import Lipophilicity
dataset = Lipophilicity(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'],
dataset = Lipophilicity(smiles_to_graph=smiles_to_g,
n_jobs=1 if args['num_workers'] == 0 else args['num_workers'])
elif args['dataset'] == 'ESOL':
from dgllife.data import ESOL
dataset = ESOL(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=args['node_featurizer'],
edge_featurizer=args['edge_featurizer'],
dataset = ESOL(smiles_to_graph=smiles_to_g,
n_jobs=1 if args['num_workers'] == 0 else args['num_workers'])
else:
raise ValueError('Unexpected dataset: {}'.format(args['dataset']))
Expand Down
2 changes: 1 addition & 1 deletion examples/property_prediction/pretrain_gnns/chem/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ This is a DGL implementation of the following paper based on PyTorch.
This paper proposed an attribute masking pre-training method. It randomly masks input node/edge attributes by replacing them with special masked indicators; the GNN then predicts those attributes based on the neighboring structure.

``` bash
python pretrain_masking.py --output_model_file OUTPUT_MODEL_FILE
python pretrain_masking.py --output_model_file OUTPUT_MODEL_FILE
```
The self-supervised pre-training model will be found in `OUTPUT_MODEL_FILE` after training (default filename: pretrain_masking.pth).

Expand Down
41 changes: 12 additions & 29 deletions examples/property_prediction/pretrain_gnns/chem/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
# https://github.com/awslabs/dgl-lifesci/blob/master/examples/property_prediction/moleculenet/classification.py

import argparse
from functools import partial
import numpy as np
import torch
import torch.nn as nn

from dgllife.utils import PretrainAtomFeaturizer
from dgllife.utils import PretrainBondFeaturizer
from dgllife.utils import smiles_to_bigraph, Meter, EarlyStopping
from dgllife.utils import Meter, EarlyStopping, SMILESToBigraph
from dgllife.model.model_zoo.gin_predictor import GINPredictor

from torch.utils.data import DataLoader
Expand Down Expand Up @@ -155,69 +154,53 @@ def main(args, dataset, device):

atom_featurizer = PretrainAtomFeaturizer()
bond_featurizer = PretrainBondFeaturizer()
smiles_to_g = SMILESToBigraph(add_self_loop=True, node_featurizer=atom_featurizer,
edge_featurizer=bond_featurizer)

if args.dataset == 'MUV':
from dgllife.data import MUV

dataset = MUV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=atom_featurizer,
edge_featurizer=bond_featurizer,
dataset = MUV(smiles_to_graph=smiles_to_g,
n_jobs=1 if args.num_workers == 0 else args.num_workers)
elif args.dataset == 'BACE':
from dgllife.data import BACE

dataset = BACE(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=atom_featurizer,
edge_featurizer=bond_featurizer,
dataset = BACE(smiles_to_graph=smiles_to_g,
n_jobs=1 if args.num_workers == 0 else args.num_workers)
elif args.dataset == 'BBBP':
from dgllife.data import BBBP

dataset = BBBP(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=atom_featurizer,
edge_featurizer=bond_featurizer,
dataset = BBBP(smiles_to_graph=smiles_to_g,
n_jobs=1 if args.num_workers == 0 else args.num_workers)
elif args.dataset == 'ClinTox':
from dgllife.data import ClinTox

dataset = ClinTox(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=atom_featurizer,
edge_featurizer=bond_featurizer,
dataset = ClinTox(smiles_to_graph=smiles_to_g,
n_jobs=1 if args.num_workers == 0 else args.num_workers)
elif args.dataset == 'SIDER':
from dgllife.data import SIDER

dataset = SIDER(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=atom_featurizer,
edge_featurizer=bond_featurizer,
dataset = SIDER(smiles_to_graph=smiles_to_g,
n_jobs=1 if args.num_workers == 0 else args.num_workers)
elif args.dataset == 'ToxCast':
from dgllife.data import ToxCast

dataset = ToxCast(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=atom_featurizer,
edge_featurizer=bond_featurizer,
dataset = ToxCast(smiles_to_graph=smiles_to_g,
n_jobs=1 if args.num_workers == 0 else args.num_workers)
elif args.dataset == 'HIV':
from dgllife.data import HIV

dataset = HIV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=atom_featurizer,
edge_featurizer=bond_featurizer,
dataset = HIV(smiles_to_graph=smiles_to_g,
n_jobs=1 if args.num_workers == 0 else args.num_workers)
elif args.dataset == 'PCBA':
from dgllife.data import PCBA

dataset = PCBA(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=atom_featurizer,
edge_featurizer=bond_featurizer,
dataset = PCBA(smiles_to_graph=smiles_to_g,
n_jobs=1 if args.num_workers == 0 else args.num_workers)
elif args.dataset == 'Tox21':
from dgllife.data import Tox21

dataset = Tox21(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=atom_featurizer,
edge_featurizer=bond_featurizer,
dataset = Tox21(smiles_to_graph=smiles_to_g,
n_jobs=1 if args.num_workers == 0 else args.num_workers)
else:
raise ValueError('Unexpected dataset: {}'.format(args.dataset))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
import random

import dgl
from dgllife.utils import PretrainAtomFeaturizer
from dgllife.utils import PretrainBondFeaturizer
from dgllife.utils import smiles_to_bigraph
from dgllife.utils import PretrainAtomFeaturizer, PretrainBondFeaturizer, SMILESToBigraph
from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive

from dgllife.model.gnn.gin import GIN
Expand Down Expand Up @@ -184,10 +182,10 @@ def main():

atom_featurizer = PretrainAtomFeaturizer()
bond_featurizer = PretrainBondFeaturizer()
smiles_to_g = SMILESToBigraph(add_self_loop=True, node_featurizer=atom_featurizer,
edge_featurizer=bond_featurizer)
dataset = PretrainDataset(data=data,
smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=atom_featurizer,
edge_featurizer=bond_featurizer,
smiles_to_graph=smiles_to_g,
smiles_column='smiles',
task='masking')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@
import argparse
import pickle
import tqdm
from functools import partial

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import dgl
from dgllife.utils import PretrainAtomFeaturizer
from dgllife.utils import PretrainBondFeaturizer
from dgllife.utils import smiles_to_bigraph
from dgllife.utils import PretrainAtomFeaturizer, PretrainBondFeaturizer, SMILESToBigraph
from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive
from dgllife.model.model_zoo.gin_predictor import GINPredictor

Expand Down Expand Up @@ -148,11 +145,10 @@ def main():

atom_featurizer = PretrainAtomFeaturizer()
bond_featurizer = PretrainBondFeaturizer()
smiles_to_g = SMILESToBigraph(add_self_loop=True, node_featurizer=atom_featurizer,
edge_featurizer=bond_featurizer)
dataset = PretrainDataset(data=data,
smiles_to_graph=partial(
smiles_to_bigraph, add_self_loop=True),
node_featurizer=atom_featurizer,
edge_featurizer=bond_featurizer,
smiles_to_graph=smiles_to_g,
task='supervised')

train_dataloader = DataLoader(dataset=dataset,
Expand Down
Loading

0 comments on commit a6f4596

Please sign in to comment.