Skip to content

Commit

Permalink
Merge pull request #93 from datamol-io/fix/collate_fn
Browse files Browse the repository at this point in the history
Fix/collate fn
  • Loading branch information
maclandrol authored Dec 19, 2023
2 parents f442be0 + 5b97ac2 commit 08a9274
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
matrix:
python-version: ["3.9"]
os: ["ubuntu-latest"] #, "macos-latest", "windows-latest"]
pytorch-version: ["1.12"]
pytorch-version: ["1.13"]

runs-on: ${{ matrix.os }}
timeout-minutes: 30
Expand Down
11 changes: 6 additions & 5 deletions env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ channels:
- dglteam

dependencies:
- python >=3.8
- python >3.8
- pip
- tqdm
- pyyaml
Expand Down Expand Up @@ -33,18 +33,19 @@ dependencies:
- mordredcommunity

# ML
- pytorch =1.12
- pytorch >=1.13
- scikit-learn
- fcd_torch

# Optional: featurizers
- dgl
- dgllife
- dgl >=1.1.1
- dgllife >=0.3.2
- graphormer-pretrained >=0.2.3
- transformers
- tokenizers <0.13.2
- sentencepiece
- biotite # required for ESM models
- biotite # required for ESM model
- pytorch_geometric >=2.4.0

# Optional: viz
- nglview
Expand Down
27 changes: 24 additions & 3 deletions molfeat/trans/graph/adj.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from functools import partial
from typing import Any
from typing import Optional
from typing import Callable
from typing import List
from typing import Union
from typing import Sequence
from typing import TYPE_CHECKING

import torch
import datamol as dm
Expand All @@ -27,10 +30,18 @@
if requires.check("dgllife"):
from dgllife import utils as dgllife_utils


if requires.check("torch_geometric"):
from torch_geometric.data import Data
from torch_geometric.loader.dataloader import Collater

if TYPE_CHECKING:
from torch_geometric.data import Dataset as PygDataset
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter
else:
PygDataset, BaseData, DatasetAdapter = Any, Any, Any


class GraphTransformer(MoleculeTransformer):
"""
Expand Down Expand Up @@ -659,7 +670,9 @@ def patch_feats(*args, **kwargs):


class PYGGraphTransformer(AdjGraphTransformer):
"""Graph transformer for the PYG models"""
"""
Graph transformer for the PYG models
"""

def _graph_featurizer(self, mol: dm.Mol):
# we have used bond_calculator, therefore we need to
Expand Down Expand Up @@ -727,23 +740,31 @@ def transform(self, mols: List[Union[dm.Mol, str]], **kwargs):

def get_collate_fn(
self,
dataset: Optional[Union[PygDataset, Sequence[BaseData], DatasetAdapter]] = None,
follow_batch: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
return_pair: Optional[bool] = True,
**kwargs,
):
"""
Get collate function for pyg graphs
Get collate function for pyg graphs.
Note: The `collate_fn` is not required when using `torch_geometric.loader.dataloader.DataLoader`.
Args:
dataset: The dataset from which to load the data and apply the collate function.
This is required if the dataset is <torch_geometric.data.on_disk_dataset.OnDiskDataset>.
follow_batch: Creates assignment batch vectors for each key in the list. (default: :obj:`None`)
exclude_keys: Will exclude each key in the list. (default: :obj:`None`)
return_pair: whether to return a pair of X,y or a databatch (default: :obj:`True`)
Returns:
Collated samples.
See Also:
<torch_geometric.loader.dataloader.Collator>
<torch_geometric.loader.dataloader.DataLoader>
"""
collator = Collater(follow_batch=follow_batch, exclude_keys=exclude_keys)
collator = Collater(dataset=dataset, follow_batch=follow_batch, exclude_keys=exclude_keys)
return partial(self._collate_batch, collator=collator, return_pair=return_pair)

@staticmethod
Expand Down
9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dynamic = ["version"]
authors = [{ name = "Emmanuel Noutahi", email = "[email protected]" }]
readme = "README.md"
license = { text = "Apache" }
requires-python = ">=3.8"
requires-python = ">3.8"
classifiers = [
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
Expand Down Expand Up @@ -39,7 +39,7 @@ dependencies = [
"pandas",
"numpy",
"scipy",
"torch",
"torch>=1.13",
"datamol >=0.8.0",
"pyyaml",
"platformdirs",
Expand All @@ -57,7 +57,7 @@ dependencies = [
]

[project.optional-dependencies]
dgl = ["dgl", "dgllife"]
dgl = ["dgl>=1.1.1", "dgllife>=0.3.2"]

graphormer = ["graphormer-pretrained"]

Expand All @@ -67,6 +67,8 @@ fcd = ["fcd_torch"]

viz = ["nglview", "ipywidgets"]

pyg = ["pytorch_geometric >=2.4.0"]

all = [
"dgl",
"dgllife",
Expand All @@ -76,6 +78,7 @@ all = [
"fcd_torch",
"nglview",
"ipywidgets",
"pytorch_geometric >=2.4.0"
]

test = ["pytest >=6.0","pytest-dotenv", "pytest-cov", "pytest-xdist", "black >=22", "ruff"]
Expand Down

0 comments on commit 08a9274

Please sign in to comment.