Skip to content

Commit

Permalink
Merge pull request #332 from RaulPPelaez/maceds
Browse files Browse the repository at this point in the history
MACE-OFF dataset
  • Loading branch information
stefdoerr authored Jul 9, 2024
2 parents c800af1 + 191f454 commit 6c42c8b
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 0 deletions.
59 changes: 59 additions & 0 deletions examples/TensorNet-MACEOFF.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
activation: silu
aggr: add
atom_filter: -1
batch_size: 16
coord_files: null
cutoff_lower: 0.0
cutoff_upper: 10.0
dataset: MACEOFF
dataset_arg:
max_gradient: 50.94
dataset_root: ~/data
derivative: true
early_stopping_patience: 50
ema_alpha_neg_dy: 1.0
ema_alpha_y: 1.0
embed_files: null
embedding_dimension: 128
energy_files: null
equivariance_invariance_group: O(3)
y_weight: 1.0
force_files: null
neg_dy_weight: 10.0
gradient_clipping: 100.0
inference_batch_size: 16
load_model: null
log_dir: logs/
lr: 0.0001
lr_factor: 0.5
lr_min: 1.0e-08
lr_patience: 5
lr_warmup_steps: 500
max_num_neighbors: 128
max_z: 128
model: tensornet
ngpus: -1
num_epochs: 500
num_layers: 2
num_nodes: 1
num_rbf: 64
num_workers: 4
output_model: Scalar
precision: 32
prior_model: null
rbf_type: expnorm
redirect: false
reduce_op: add
save_interval: 10
splits: null
seed: 1
standardize: false
test_interval: 10
test_size: null
train_size: 0.8
trainable_rbf: false
val_size: 0.1
weight_decay: 0.0
box_vecs: null
charge: false
spin: false
2 changes: 2 additions & 0 deletions torchmdnet/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .qm9q import QM9q
from .spice import SPICE
from .genentech import GenentechTorsions
from .maceoff import MACEOFF

__all__ = [
"Ace",
Expand All @@ -47,4 +48,5 @@
"SPICE",
"Tripeptides",
"WaterBox",
"MACEOFF",
]
138 changes: 138 additions & 0 deletions torchmdnet/datasets/maceoff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
# Distributed under the MIT License.
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

import hashlib
from ase.data import atomic_numbers
import numpy as np
import os
import torch as pt
from torchmdnet.datasets.memdataset import MemmappedDataset
from torch_geometric.data import Data, download_url
import tarfile
import logging
import re
from tqdm import tqdm


def parse_maceoff_tar(tar_file):
energy_re = re.compile("energy=(\S+)")
with tarfile.open(tar_file, "r:gz") as tar:
for member in tar.getmembers():
f = tar.extractfile(member)
if f is None:
continue
n_atoms = None
counter = 0
positions = []
numbers = []
forces = []
energy = None
for line in f:
line = line.decode("utf-8").strip()
if n_atoms is None:
n_atoms = int(line)
positions = []
numbers = []
forces = []
energy = None
counter = 1
continue
if counter == 1:
props = line
energy = float(energy_re.search(props).group(1))
counter = 2
continue
el, x, y, z, fx, fy, fz, _, _, _ = line.split()
numbers.append(atomic_numbers[el])
positions.append([float(x), float(y), float(z)])
forces.append([float(fx), float(fy), float(fz)])
counter += 1
if counter == n_atoms + 2:
n_atoms = None
yield energy, numbers, positions, forces


class MACEOFF(MemmappedDataset):
"""
MACEOFF dataset from MACE-OFF23: Transferable Machine Learning Force Fields for Organic Molecules, Kovacs et.al. https://arxiv.org/abs/2312.15211
This dataset consists of arounf 100K conformations with 95% of them coming from SPICE and augmented with conformations from QMugs, COMP6 and clusters of water carved out of MD simulations of liquid water.
From the repository:
The core of the training set is the SPICE dataset. 95% of the data were used for training and validation, and 5% for testing. The MACE-OFF23 model is trained to reproduce the energies and forces computed at the ωB97M-D3(BJ)/def2-TZVPPD level of quantum mechanics, as implemented in the PSI4 software. We have used a subset of SPICE that contains the ten chemical elements H, C, N, O, F, P, S, Cl, Br, and I, and has a neutral formal charge. We have also removed the ion pairs subset. Overall, we used about 85% of the full SPICE dataset.
Contains energy and force data in units of eV and eV/Angstrom
"""

VERSIONS = {
"1.0": {
"url": "https://api.repository.cam.ac.uk/server/api/core/bitstreams/b185b5ab-91cf-489a-9302-63bfac42824a/content",
"file": "train_large_neut_no_bad_clean.tar.gz",
},
}

@property
def raw_dir(self):
return os.path.join(super().raw_dir, "maceoff", self.version)

@property
def raw_file_names(self):
return self.VERSIONS[self.version]["file"]

@property
def raw_url(self):
return f"{self.VERSIONS[self.version]['url']}"

def __init__(
self,
root=None,
transform=None,
pre_transform=None,
pre_filter=None,
version="1.0",
max_gradient=None,
):
arg_hash = f"{version}{max_gradient}"
arg_hash = hashlib.md5(arg_hash.encode()).hexdigest()
self.name = f"{self.__class__.__name__}-{arg_hash}"
self.version = str(version)
assert self.version in self.VERSIONS
self.max_gradient = max_gradient
super().__init__(
root,
transform,
pre_transform,
pre_filter,
properties=("y", "neg_dy"),
)

def sample_iter(self, mol_ids=False):
assert len(self.raw_paths) == 1
logging.info(f"Processing dataset {self.raw_file_names}")
for energy, numbers, positions, forces in tqdm(
parse_maceoff_tar(self.raw_paths[0]), desc="Processing conformations"
):
data = Data(
**dict(
z=pt.tensor(np.array(numbers), dtype=pt.long),
pos=pt.tensor(positions, dtype=pt.float32),
y=pt.tensor(energy, dtype=pt.float64).view(1, 1),
neg_dy=pt.tensor(forces, dtype=pt.float32),
)
)
assert data.y.shape == (1, 1)
assert data.z.shape[0] == data.pos.shape[0]
assert data.neg_dy.shape[0] == data.pos.shape[0]
# Skip samples with large forces
if self.max_gradient:
if data.neg_dy.norm(dim=1).max() > float(self.max_gradient):
continue
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
yield data

def download(self):
download_url(self.raw_url, self.raw_dir, filename=self.raw_file_names)

0 comments on commit 6c42c8b

Please sign in to comment.