Ask about the issue of units during the training process #43

Closed · lingcon01 opened this issue Sep 24, 2024 · 9 comments

lingcon01 commented Sep 24, 2024

Hello, I would like to ask about energy prediction using SchNet. According to the results you provided, under the ST split of the tiny dataset, is the MAE 1.17 Hartree or 1.17 × 10^-2 Hartree? I tried to reproduce the result and found that the MAE on the test set is around 0.75, but I am not sure about the unit. Could you clarify this?

@BerAnton (Contributor)

Hello!
Could you please share the run config file?


lingcon01 commented Sep 25, 2024

## config

```yaml
train:
  batch_size: 128
  seed: 2021
  epochs: 500
  num_workers: 0
  restore_path: null
  data_path: /home/suqun/data/nablaDFT/energy/tiny/train
  data_name: train_2k_v2_formation_energy_w_forces
  save_path: checkpoints/train/md17
  log_interval: 10
  lr: 0.0002
  factor: 0.9
  patience: 30
  min_lr: 0.000001
  energy_weight: 1
  force_weight: 0
  weight_decay: 1e-16

test:
  test_interval: 1
  test_batch_size: 100
  data_path: /home/suqun/data/nablaDFT/energy/tiny/test
  data_name: test_2k_conformers_v2_formation_energy_w_forces
```

## loss

```python
loss_energy = loss_l1(pred, label)
loss_force = loss_l1(dy, pdy)
```

## data_process

```python
# Build PyG Data objects from the ASE database rows.
for db_row in tqdm(db.select(), total=len(db)):
    z = torch.from_numpy(db_row.numbers.copy()).long()
    positions = torch.from_numpy(db_row.positions.copy()).float()
    y = torch.from_numpy(np.array(db_row.data["energy"])).float()
    forces = torch.from_numpy(np.array(db_row.data["forces"])).float()
    molecule_size = len(positions)
    samples.append(Data(z=z, pos=positions, y=y, dy=forces,
                        molecule_size=molecule_size))
```

## schnet

```python
from typing import List, Optional

import torch
from torch.nn import Embedding
from torch_geometric.nn import radius_graph

# `emb`, `update_v`, `update_e`, `update_u` are the building blocks of the
# DIG-style SchNet implementation this class is adapted from.


class SchNet(torch.nn.Module):
    r"""Re-implementation of SchNet from the `"SchNet: A Continuous-filter
    Convolutional Neural Network for Modeling Quantum Interactions"
    <https://arxiv.org/abs/1706.08566>`_ paper under the 3DGN framework from the
    `"Spherical Message Passing for 3D Molecular Graphs"
    <https://openreview.net/forum?id=givsRXsOt9r>`_ paper.

    Args:
        energy_and_force (bool, optional): If set to :obj:`True`, will predict
            energy and take the negative of the derivative of the energy with
            respect to the atomic positions as predicted forces.
            (default: :obj:`False`)
        num_layers (int, optional): The number of layers. (default: :obj:`6`)
        hidden_channels (int, optional): Hidden embedding size. (default: :obj:`128`)
        out_channels (int, optional): Output embedding size. (default: :obj:`1`)
        num_filters (int, optional): The number of filters to use. (default: :obj:`128`)
        num_gaussians (int, optional): The number of gaussians :math:`\mu`. (default: :obj:`50`)
        cutoff (float, optional): Cutoff distance for interatomic interactions. (default: :obj:`10.0`)
    """

    def __init__(self, energy_and_force=False, cutoff=10.0, num_layers=6,
                 hidden_channels=128, out_channels=1, num_filters=128,
                 num_gaussians=50):
        super(SchNet, self).__init__()

        self.energy_and_force = energy_and_force
        self.cutoff = cutoff
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_filters = num_filters
        self.num_gaussians = num_gaussians

        self.init_v = Embedding(100, hidden_channels)
        self.dist_emb = emb(0.0, cutoff, num_gaussians)

        self.update_vs = torch.nn.ModuleList(
            [update_v(hidden_channels, num_filters) for _ in range(num_layers)])
        self.update_es = torch.nn.ModuleList(
            [update_e(hidden_channels, num_filters, num_gaussians, cutoff)
             for _ in range(num_layers)])
        self.update_u = update_u(hidden_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        self.init_v.reset_parameters()
        for update_e in self.update_es:
            update_e.reset_parameters()
        for update_v in self.update_vs:
            update_v.reset_parameters()
        self.update_u.reset_parameters()

    def forward(self, batch_data, mean, std):
        z, pos, batch = batch_data.z, batch_data.pos, batch_data.batch
        if self.energy_and_force:
            pos.requires_grad_()

        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        row, col = edge_index
        dist = (pos[row] - pos[col]).norm(dim=-1)
        dist_emb = self.dist_emb(dist)

        v = self.init_v(z)

        for update_e, update_v in zip(self.update_es, self.update_vs):
            e = update_e(v, dist, dist_emb, edge_index)
            v = update_v(v, e, edge_index)
        u = self.update_u(v, batch)

        # Forces as the negative gradient of the predicted energy w.r.t.
        # positions; only valid when pos.requires_grad was set above.
        forces = None
        if self.energy_and_force:
            grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(u)]
            forces = -torch.autograd.grad(
                outputs=[u],        # [n_graphs, ]
                inputs=[pos],       # [n_nodes, 3]
                grad_outputs=grad_outputs,
                retain_graph=True,  # keep the graph for the loss backward pass
                create_graph=True,  # needed for second derivatives in training
                allow_unused=True,
            )[0]                    # [n_nodes, 3]

        return u, forces
```
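For context, a minimal sketch of how this class would be driven (assuming the DIG-style helpers `emb`, `update_v`, `update_e`, `update_u` are importable and `samples` is the list built in the data_process snippet above; this is illustrative, not the nablaDFT training loop):

```python
from torch_geometric.loader import DataLoader

# Hypothetical driver for the class above.
model = SchNet(energy_and_force=True)
loader = DataLoader(samples, batch_size=128, shuffle=True)

for batch in loader:
    # mean/std are accepted but unused by this forward implementation
    energy, forces = model(batch, mean=None, std=None)
    # energy: per-graph predictions; forces: per-atom force predictions
    break
```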

@BerAnton (Contributor)

First of all, the model parameters differ from nablaDFT's SchNet.
Also, you are using the conformers test split, not the structures (ST) split.

The preferred way to reproduce the results is to use the run.py script with the desired checkpoint and test data split.
Please refer to the run configuration README for more details.

If you want to reproduce SchNet trained on the tiny data split on the structures (ST) test, use this command:

```bash
python run.py --config schnet_test.yaml
```

schnet_test.yaml:

```yaml
# Global variables
name: SchNet
dataset_name: dataset_test_structures
max_steps: 1000000
job_type: test
pretrained: SchNet_train_tiny # name of pretrained split or 'null'
ckpt_path: null # path to checkpoint for training resume or test run

# Datamodule parameters
root: ./datasets/nablaDFT/${.job_type}
batch_size: 32
num_workers: 8

# Devices
devices: [0]

# configs
defaults:
  - _self_
  - datamodule: nablaDFT_ase_test.yaml  # dataset config
  - model: schnet.yaml  # model config
  - callbacks: callbacks_spk.yaml  # pl callbacks config
  - loggers: wandb.yaml  # pl loggers config
  - trainer: test.yaml  # trainer config

# need this to set working dir as current dir
hydra:
  output_subdir: null
  run:
    dir: .
original_work_dir: ${hydra:runtime.cwd}

seed: 23
```

@lingcon01 (Author)

Thanks for your explanation, I understand now. I have another question: I used PygnablaDFT to read the dataset test_2k_conformers_v2_formation_energy_w_forces_test and saved it to a .pt file. Are the energy values in this dataset in Hartree?

```python
a = torch.load('test_2k_conformers_v2_formation_energy_w_forces_test.pt')
print(a[0].y)
# tensor([-6.0566, -6.0559, -6.0675, ..., -7.9457, -7.4641, -7.4602])
```

@KuzmaKhrabrov (Contributor)

All energy values in the energy databases are in Hartree.
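For reference, converting a Hartree-scale MAE to other common units uses standard physical constants, nothing nablaDFT-specific; a quick sketch:

```python
# Standard unit conversions (CODATA values).
HARTREE_TO_EV = 27.211386           # 1 Hartree in eV
HARTREE_TO_KCAL_PER_MOL = 627.5095  # 1 Hartree in kcal/mol

mae_hartree = 1.17e-2  # example: one reading of the 1.17 figure discussed above
print(mae_hartree * HARTREE_TO_EV)            # ~0.318 eV
print(mae_hartree * HARTREE_TO_KCAL_PER_MOL)  # ~7.34 kcal/mol
```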

@lingcon01 (Author)

Thanks for your response! For training setups that include both energy and force terms, is the loss typically calculated as energy * 1 + force * 1?

@KuzmaKhrabrov (Contributor)

In general, these are hyperparameters of the training pipeline. For instance, we used an energy:force ratio of 1:100 for training SchNet, PaiNN, and GemNet, 1:1 for DimeNet, and 2:100 for EquiformerV2.
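In code, such a weighted objective would look roughly like the following sketch (variable and function names are illustrative, not the actual nablaDFT pipeline code):

```python
import torch.nn.functional as F

def weighted_loss(pred_energy, true_energy, pred_forces, true_forces,
                  energy_weight=1.0, force_weight=100.0):
    # e.g. the 1:100 energy:force ratio mentioned above for SchNet/PaiNN/GemNet
    loss_energy = F.l1_loss(pred_energy, true_energy)
    loss_forces = F.l1_loss(pred_forces, true_forces)
    return energy_weight * loss_energy + force_weight * loss_forces
```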

@lingcon01 (Author)

Thanks for your response! I would like to ask whether this dataset contains force labels:

```
"dataset_train_full": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_full_v2_formation_energy.db"
```

@KuzmaKhrabrov (Contributor)