Skip to content

Commit

Permalink
Merge pull request #233 from RaulPPelaez/fix_module
Browse files Browse the repository at this point in the history
Fix trying to detach a float number in LNNP
  • Loading branch information
stefdoerr authored Oct 23, 2023
2 parents 2cc5395 + 51f666c commit 3ae06ce
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torchmdnet/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def _compute_losses(self, y, neg_y, batch, loss_fn, stage):
# Returns:
# loss_y: loss for the predicted value
# loss_neg_y: loss for the predicted negative derivative
loss_y, loss_neg_y = 0.0, 0.0
loss_y, loss_neg_y = torch.tensor(0.0, device=self.device), torch.tensor(
0.0, device=self.device
)
loss_name = loss_fn.__name__
if self.hparams.derivative and "neg_dy" in batch:
loss_neg_y = loss_fn(neg_y, batch.neg_dy)
Expand Down

0 comments on commit 3ae06ce

Please sign in to comment.