Skip to content

Commit

Permalink
Merge pull request #281 from RaulPPelaez/datasets
Browse files Browse the repository at this point in the history
Clean up some datasets
  • Loading branch information
RaulPPelaez authored Feb 14, 2024
2 parents 417f8a0 + 7063af0 commit c0edfed
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 57 deletions.
99 changes: 51 additions & 48 deletions torchmdnet/datasets/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,29 +45,29 @@ def raw_url(self):
def raw_file_names(self):
raise NotImplementedError

def compute_reference_energy(self, atomic_numbers):
    """Return the summed per-element reference energy for a set of atoms.

    Args:
        atomic_numbers: Sequence of atomic numbers, one entry per atom.
            Every value must be a key of ``self._ELEMENT_ENERGIES``.

    Returns:
        The sum of the tabulated element energies multiplied by
        ``self.HARTREE_TO_EV`` (``0.0`` for an empty sequence).

    Raises:
        KeyError: If an atomic number has no entry in
            ``self._ELEMENT_ENERGIES``.
    """
    # asarray avoids copying inputs that already are NumPy arrays.
    atomic_numbers = np.asarray(atomic_numbers)
    energy = sum(self._ELEMENT_ENERGIES[z] for z in atomic_numbers)
    # Read the conversion factor through ``self`` like get_atomref does,
    # instead of hard-coding the base-class name.
    return energy * self.HARTREE_TO_EV

def get_atomref(self, max_z=100):
    """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.

    Args:
        max_z (int, optional): Number of rows in the returned table. It must
            exceed the largest atomic number in ``self._ELEMENT_ENERGIES``,
            otherwise the assignment below raises ``IndexError``.

    Returns:
        A tensor of shape ``(max_z, 1)``. Row ``z`` holds
        ``self._ELEMENT_ENERGIES[z] * self.HARTREE_TO_EV`` for every
        tabulated element and zero for all other rows.
    """
    refs = pt.zeros(max_z)
    for key, val in self._ELEMENT_ENERGIES.items():
        refs[key] = val * self.HARTREE_TO_EV

    # Column-vector layout: one row per atomic number.
    return refs.view(-1, 1)

def __init__(
    self,
    root,
    transform=None,
    pre_transform=None,
    pre_filter=None,
    properties=("y", "neg_dy"),
):
    """Initialise the dataset and forward all options to the parent class.

    Args:
        root: Forwarded unchanged to the parent constructor.
        transform (optional): Forwarded unchanged to the parent constructor.
        pre_transform (optional): Forwarded unchanged to the parent
            constructor.
        pre_filter (optional): Forwarded unchanged to the parent constructor.
        properties (tuple of str, optional): Properties requested from the
            parent class. Defaults to ``("y", "neg_dy")``.
    """
    # Set before calling the parent constructor.
    # NOTE(review): presumably the parent needs the name while it runs - confirm.
    self.name = self.__class__.__name__
    super().__init__(
        root,
        transform,
        pre_transform,
        pre_filter,
        remove_ref_energy=True,
        # Pass the caller's choice through exactly once; a repeated
        # ``properties=`` keyword is a SyntaxError.
        properties=properties,
    )

def filter_and_pre_transform(self, data):
Expand All @@ -82,13 +82,28 @@ def filter_and_pre_transform(self, data):

class ANI1(ANIBase):
__doc__ = ANIBase.__doc__
# Per-element energies keyed by atomic number (1=H, 6=C, 7=N, 8=O).
# The values are multiplied by HARTREE_TO_EV wherever they are used.
_ELEMENT_ENERGIES = {
    1: -0.500607632585,
    6: -37.8302333826,
    7: -54.5680045287,
    8: -75.0362229210,
}

def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
    """Initialise the dataset, requesting only the ``y`` property.

    All arguments are forwarded positionally to the parent constructor.
    """
    self.name = type(self).__name__
    super().__init__(root, transform, pre_transform, pre_filter, properties=("y",))

@property
def raw_url(self):
Expand Down Expand Up @@ -134,16 +149,6 @@ def sample_iter(self, mol_ids=False):
if data := self.filter_and_pre_transform(data):
yield data

def get_atomref(self, max_z=100):
    """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior."""
    # (atomic number, energy) pairs; each energy is scaled by HARTREE_TO_EV below.
    hartree_by_z = (
        (1, -0.500607632585),  # H
        (6, -37.8302333826),  # C
        (7, -54.5680045287),  # N
        (8, -75.0362229210),  # O
    )
    table = pt.zeros(max_z)
    for z, energy in hartree_by_z:
        table[z] = energy * self.HARTREE_TO_EV
    # One row per atomic number, single column.
    return table.view(-1, 1)


class ANI1XBase(ANIBase):
@property
Expand All @@ -159,31 +164,15 @@ def download(self):
assert len(self.raw_paths) == 1
os.rename(file, self.raw_paths[0])

def get_atomref(self, max_z=100):
    """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior."""
    warnings.warn("Atomic references from the ANI-1 dataset are used!")

    # (atomic number, energy) pairs; each energy is scaled by HARTREE_TO_EV below.
    hartree_by_z = (
        (1, -0.500607632585),  # H
        (6, -37.8302333826),  # C
        (7, -54.5680045287),  # N
        (8, -75.0362229210),  # O
    )
    table = pt.zeros(max_z)
    for z, energy in hartree_by_z:
        table[z] = energy * self.HARTREE_TO_EV
    # One row per atomic number, single column.
    return table.view(-1, 1)


class ANI1X(ANI1XBase):
__doc__ = ANIBase.__doc__
# Per-element energies keyed by atomic number. Each key is listed exactly
# once: in a dict literal a repeated key silently overrides the earlier entry.
_ELEMENT_ENERGIES = {
    1: -0.600952980000,
    6: -38.08316124000,
    7: -54.70775770000,
    8: -75.19446356000,
}
"""
:meta private:
"""

def sample_iter(self, mol_ids=False):
assert len(self.raw_paths) == 1
Expand Down Expand Up @@ -223,6 +212,28 @@ def sample_iter(self, mol_ids=False):

class ANI1CCX(ANI1XBase):
__doc__ = ANIBase.__doc__
# Per-element energies keyed by atomic number (1, 6, 7, 8).
# NOTE(review): presumably in Hartree, since the ANI base code multiplies
# these values by HARTREE_TO_EV - confirm against the dataset source.
_ELEMENT_ENERGIES = {
1: -0.5991501324919538,
6: -38.03750806057356,
7: -54.67448347695333,
8: -75.16043537275567,
}

def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
    """Initialise the dataset, requesting only the ``y`` property.

    All arguments are forwarded positionally to the parent constructor.
    """
    self.name = type(self).__name__
    super().__init__(root, transform, pre_transform, pre_filter, properties=("y",))

def sample_iter(self, mol_ids=False):
assert len(self.raw_paths) == 1
Expand Down Expand Up @@ -319,11 +330,3 @@ def sample_iter(self, mol_ids=False):

if data := self.filter_and_pre_transform(data):
yield data

def get_atomref(self, max_z=100):
    """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior."""
    energies = self._ELEMENT_ENERGIES
    table = pt.zeros(max_z)
    # Fill one row per tabulated atomic number; all other rows stay zero.
    for z in energies:
        table[z] = energies[z] * self.HARTREE_TO_EV
    return table.view(-1, 1)
5 changes: 0 additions & 5 deletions torchmdnet/datasets/comp6.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,3 @@ def get_atomref(self, max_z=100):
refs[key] = val * self.HARTREE_TO_EV

return refs.view(-1, 1)

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
    """Delegate to the parent class's ``process`` without adding behaviour.

    The override exists only as the workaround described in the comment above.
    """
    super().process()
16 changes: 15 additions & 1 deletion torchmdnet/datasets/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ class Custom(Dataset):
forceglob (string, optional): Glob path for force files. Stored as "neg_dy".
(default: :obj:`None`)
preload_memory_limit (int, optional): If the dataset is smaller than this limit (in MB), preload it into CPU memory.
transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before every access.
pre_transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before being saved to disk.
pre_filter (callable, optional): A function that takes in an
:obj:`torch_geometric.data.Data` object and returns a boolean value,
indicating whether the data object should be included in the final
dataset.
Example:
>>> data = Custom(coordglob="coords_files*npy", embedglob="embed_files*npy")
Expand All @@ -45,8 +56,11 @@ def __init__(
energyglob=None,
forceglob=None,
preload_memory_limit=1024,
transform=None,
pre_transform=None,
pre_filter=None,
):
super(Custom, self).__init__()
super().__init__(None, transform, pre_transform, pre_filter)
assert energyglob is not None or forceglob is not None, (
"Either energies, forces or both must " "be specified as the target"
)
Expand Down
23 changes: 20 additions & 3 deletions torchmdnet/datasets/memdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,24 @@ class MemmappedDataset(Dataset):
- :obj:`name.pq.mmap`: Partial charges of all the atoms.
- :obj:`name.dp.mmap`: Dipole moment of each conformation.
Args:
root (str): Root directory where the dataset should be stored.
transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before every access.
pre_transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before being saved to disk.
pre_filter (callable, optional): A function that takes in an
:obj:`torch_geometric.data.Data` object and returns a boolean value,
indicating whether the data object should be included in the final
dataset.
remove_ref_energy (bool, optional): If set to :obj:`True`, the reference
energy will be subtracted from the energy of each conformation before
returning it.
properties (tuple of str, optional): The properties to include in the
dataset. Can be any subset of :obj:`y`, :obj:`neg_dy`, :obj:`q`,
:obj:`pq`, and :obj:`dp`.
"""

def __init__(
Expand Down Expand Up @@ -92,9 +110,8 @@ def processed_paths_dict(self):
)
}

def compute_reference_energy(self, atomic_numbers):
    """Return the summed reference energy for the given atoms.

    Args:
        atomic_numbers: Index (for example an integer tensor or a list of
            ints) used to select rows of the table returned by
            ``self.get_atomref()``, one entry per atom.

    Returns:
        The sum over the selected entries of ``self.get_atomref()``.
    """
    return self.get_atomref()[atomic_numbers].sum()

def sample_iter(self, mol_ids=False):
    """Placeholder that always raises ``NotImplementedError``.

    Args:
        mol_ids (bool, optional): Unused by this implementation.
    """
    raise NotImplementedError
Expand Down

0 comments on commit c0edfed

Please sign in to comment.