Skip to content

Commit

Permalink
Merge pull request #259 from torchmd/genentech_torsions
Browse files Browse the repository at this point in the history
Genentech torsions
  • Loading branch information
RaulPPelaez authored Jan 26, 2024
2 parents 3e11258 + 16af1d3 commit 05d3d2c
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 0 deletions.
2 changes: 2 additions & 0 deletions torchmdnet/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .qm9 import QM9
from .qm9q import QM9q
from .spice import SPICE
from .genentech import GenentechTorsions

__all__ = [
"Ace",
Expand All @@ -36,6 +37,7 @@
"DrugBank",
"GDB07to09",
"GDB10to13",
"GenentechTorsions",
"HDF5",
"MD17",
"MD22",
Expand Down
130 changes: 130 additions & 0 deletions torchmdnet/datasets/genentech.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
# Distributed under the MIT License.
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

import numpy as np
import os
import torch as pt
from torch_geometric.data import Data, download_url
from torchmdnet.datasets.memdataset import MemmappedDataset
from torchmdnet.utils import ATOMIC_NUMBERS


class GenentechTorsions(MemmappedDataset):
"""Dataset of torsion scans of small molecules.
This is a dataset consisting of torsion scans of small molecules.
Gas-phase geometries and energies are calculated with CCSD(T)/CBS theory.
By default we load the relative energies in the dataset which are relative to the minimum energy of the torsion scan.
References:
- https://pubs.acs.org/doi/10.1021/acs.jcim.6b00614
"""

KCALMOL_TO_EV = 0.0433641153087705

def __init__(
self,
root=None,
transform=None,
pre_transform=None,
pre_filter=None,
paths=None,
theory="CCSD_T_CBS_MP2",
energy_field="deltaE",
):
self.name = self.__class__.__name__
self.paths = str(paths)
self.theory = theory
self.energy_field = energy_field
super().__init__(
root,
transform,
pre_transform,
pre_filter,
remove_ref_energy=False,
properties=("y"),
)

@property
def raw_url(self):
return "https://github.com/Acellera/sellers/raw/main/ci6b00614_si_002.zip"

@property
def raw_file_names(self):
return [
"QM_MM_Gas_Phase_Torsion_Scan_Individual_Results_with_CCSD_T_CBS_baseline.sdf"
]

def download(self):
import zipfile

archive = download_url(self.raw_url, self.raw_dir)

with zipfile.ZipFile(archive, "r") as zip_ref:
zip_ref.extractall(self.raw_dir)
os.remove(archive)

def sample_iter(self, mol_ids=False):
assert len(self.raw_paths) == 1

with open(self.raw_paths[0]) as f:
molstart_count = 0
discard_molecule = False
deltaE = None
mol_id = None
num_atoms = None
scan_atoms = None
z = []
pos = []
for line in f:
if discard_molecule and not line.strip().startswith("$$$$"):
continue
if molstart_count >= 0 and molstart_count < 4:
molstart_count += 1
if molstart_count == 4: # On the 4th line we read atom counts
num_atoms = int(line.strip().split()[0])
molstart_count = -1 # Start atom/bond section
continue
if line.strip().startswith("$$$$"):
if not discard_molecule:
data = Data(
z=pt.tensor(z, dtype=pt.long),
pos=pt.tensor(np.vstack(pos), dtype=pt.float32),
y=pt.tensor(deltaE * self.KCALMOL_TO_EV, dtype=pt.float64),
mol_id=mol_id,
scan_atoms=scan_atoms,
)
yield data

molstart_count = 0
discard_molecule = False
deltaE = None
mol_id = None
num_atoms = None
scan_atoms = None
z = []
pos = []
continue

# Parsing the atom section
if num_atoms is not None:
num_atoms -= 1
if num_atoms >= 0:
pos_x, pos_y, pos_z, el = line.strip().split()[:4]
pos.append([float(pos_x), float(pos_y), float(pos_z)])
z.append(ATOMIC_NUMBERS[el])

# Parsing the SDF properties
if line.strip().startswith("> <MinMethod>"):
min_method = next(f).strip()
if min_method != self.theory:
discard_molecule = True
continue
if line.strip().startswith(f"> <{self.energy_field}>"):
deltaE = float(next(f).strip())
if line.strip().startswith("> <Number>"):
mol_id = int(next(f).strip())
if line.strip().startswith("> <ScanAtoms_1>"):
scan_atoms = map(int, next(f).strip().split())
124 changes: 124 additions & 0 deletions torchmdnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,127 @@
])
# fmt: on

ATOMIC_NUMBERS = {
"H": 1,
"He": 2,
"Li": 3,
"Be": 4,
"B": 5,
"C": 6,
"N": 7,
"O": 8,
"F": 9,
"Ne": 10,
"Na": 11,
"Mg": 12,
"Al": 13,
"Si": 14,
"P": 15,
"S": 16,
"Cl": 17,
"Ar": 18,
"K": 19,
"Ca": 20,
"Sc": 21,
"Ti": 22,
"V": 23,
"Cr": 24,
"Mn": 25,
"Fe": 26,
"Co": 27,
"Ni": 28,
"Cu": 29,
"Zn": 30,
"Ga": 31,
"Ge": 32,
"As": 33,
"Se": 34,
"Br": 35,
"Kr": 36,
"Rb": 37,
"Sr": 38,
"Y": 39,
"Zr": 40,
"Nb": 41,
"Mo": 42,
"Tc": 43,
"Ru": 44,
"Rh": 45,
"Pd": 46,
"Ag": 47,
"Cd": 48,
"In": 49,
"Sn": 50,
"Sb": 51,
"Te": 52,
"I": 53,
"Xe": 54,
"Cs": 55,
"Ba": 56,
"La": 57,
"Ce": 58,
"Pr": 59,
"Nd": 60,
"Pm": 61,
"Sm": 62,
"Eu": 63,
"Gd": 64,
"Tb": 65,
"Dy": 66,
"Ho": 67,
"Er": 68,
"Tm": 69,
"Yb": 70,
"Lu": 71,
"Hf": 72,
"Ta": 73,
"W": 74,
"Re": 75,
"Os": 76,
"Ir": 77,
"Pt": 78,
"Au": 79,
"Hg": 80,
"Tl": 81,
"Pb": 82,
"Bi": 83,
"Po": 84,
"At": 85,
"Rn": 86,
"Fr": 87,
"Ra": 88,
"Ac": 89,
"Th": 90,
"Pa": 91,
"U": 92,
"Np": 93,
"Pu": 94,
"Am": 95,
"Cm": 96,
"Bk": 97,
"Cf": 98,
"Es": 99,
"Fm": 100,
"Md": 101,
"No": 102,
"Lr": 103,
"Rf": 104,
"Db": 105,
"Sg": 106,
"Bh": 107,
"Hs": 108,
"Mt": 109,
"Ds": 110,
"Rg": 111,
"Cn": 112,
"Nh": 113,
"Fl": 114,
"Mc": 115,
"Lv": 116,
"Ts": 117,
"Og": 118,
}


def train_val_test_split(dset_len, train_size, val_size, test_size, seed, order=None):
assert (train_size is None) + (val_size is None) + (
Expand Down Expand Up @@ -224,6 +345,7 @@ def number(text):
class MissingEnergyException(Exception):
pass


def write_as_hdf5(files, hdf5_dataset):
"""Transform the input numpy files to hdf5 format compatible with the HDF5 Dataset class.
The input files to this function are the same as the ones required by the Custom dataset.
Expand All @@ -239,6 +361,7 @@ def write_as_hdf5(files, hdf5_dataset):
>>> write_as_hdf5(files, join(tmpdir, "test.hdf5"))
"""
import h5py

with h5py.File(hdf5_dataset, "w") as f:
for i in range(len(files["pos"])):
# Create a group for each file
Expand All @@ -255,6 +378,7 @@ def write_as_hdf5(files, hdf5_dataset):
force_data = np.load(files["neg_dy"][i], mmap_mode="r")
group.create_dataset("forces", data=force_data)


def deprecated_class(cls):
"""Decorator to mark classes as deprecated."""
orig_init = cls.__init__
Expand Down

0 comments on commit 05d3d2c

Please sign in to comment.