Add dtype parameter to External (#274)
* Add dtype parameter to External

* simplify

* cast box too

* Allow to pass dtype as a string

* Understand a 1D box as a diagonal one

* Understand box batching

* typo

* typo

* Update test

* Force static_shapes in cuda graph mode
Allow External to take kwargs for load_model

* Add some warnings
RaulPPelaez committed Feb 19, 2024
1 parent 853d838 commit e79b9c0
Showing 2 changed files with 71 additions and 42 deletions.
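
For context, a minimal usage sketch of the updated External interface described in the commit message above. The checkpoint path and the extra load_model keyword are hypothetical; dtype may be passed as a torch.dtype or as a string:

    import torch
    from torchmdnet.calculators import External

    z = torch.tensor([[8, 1, 1]])      # atomic numbers, shape (1, n_atoms)
    calc = External(
        "model.ckpt",                  # hypothetical checkpoint path
        z,
        device="cpu",
        dtype="float64",               # new: accepted as torch.float64 or "float64"
        max_num_neighbors=64,          # new: extra kwargs are forwarded to load_model (assumed override)
    )
    pos = torch.rand(3, 3)             # positions are cast to the requested dtype inside calculate()
    energy, forces = calc.calculate(pos)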
70 changes: 33 additions & 37 deletions tests/test_calculator.py
@@ -12,49 +12,45 @@
from utils import create_example_batch


def test_compare_forward():
checkpoint = join(dirname(dirname(__file__)), "tests", "example.ckpt")
z, pos, _ = create_example_batch(multiple_batches=False)
calc = External(checkpoint, z.unsqueeze(0))
model = load_model(checkpoint, derivative=True)

e_calc, f_calc = calc.calculate(pos, None)
e_pred, f_pred = model(z, pos)

assert_allclose(e_calc, e_pred)
assert_allclose(f_calc, f_pred.unsqueeze(0))

@pytest.mark.parametrize("box", [None, torch.eye(3)])
def test_compare_forward_cuda_graph(box):
if not torch.cuda.is_available():
@pytest.mark.parametrize("use_cuda_graphs", [True, False])
def test_compare_forward(box, use_cuda_graphs):
if use_cuda_graphs and not torch.cuda.is_available():
pytest.skip("CUDA not available")
checkpoint = join(dirname(dirname(__file__)), "tests", "example.ckpt")
args = {"model": "tensornet",
"embedding_dimension": 128,
"num_layers": 2,
"num_rbf": 32,
"rbf_type": "expnorm",
"trainable_rbf": False,
"activation": "silu",
"cutoff_lower": 0.0,
"cutoff_upper": 5.0,
"max_z": 100,
"max_num_neighbors": 128,
"equivariance_invariance_group": "O(3)",
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32 }
model = create_model(args).to(device="cuda")
args = {
"model": "tensornet",
"embedding_dimension": 128,
"num_layers": 2,
"num_rbf": 32,
"rbf_type": "expnorm",
"trainable_rbf": False,
"activation": "silu",
"cutoff_lower": 0.0,
"cutoff_upper": 5.0,
"max_z": 100,
"max_num_neighbors": 128,
"equivariance_invariance_group": "O(3)",
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32,
}
device = "cpu" if not use_cuda_graphs else "cuda"
model = create_model(args).to(device=device)
z, pos, _ = create_example_batch(multiple_batches=False)
z = z.to("cuda")
pos = pos.to("cuda")
calc = External(checkpoint, z.unsqueeze(0), use_cuda_graph=False, device="cuda")
calc_graph = External(checkpoint, z.unsqueeze(0), use_cuda_graph=True, device="cuda")
z = z.to(device)
pos = pos.to(device)
calc = External(checkpoint, z.unsqueeze(0), use_cuda_graph=False, device=device)
calc_graph = External(
checkpoint, z.unsqueeze(0), use_cuda_graph=use_cuda_graphs, device=device
)
calc.model = model
calc_graph.model = model
if box is not None:
box = (box * 2 * args["cutoff_upper"]).unsqueeze(0)
for _ in range(10):
e_calc, f_calc = calc.calculate(pos, box)
e_pred, f_pred = calc_graph.calculate(pos, box)
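The commit message also states that a 1D box is now understood as a diagonal one; a hedged sketch of that interpretation (the expansion itself is not part of this diff):

    import torch

    box_1d = torch.tensor([10.0, 10.0, 10.0])   # three box lengths
    box_3x3 = torch.diag(box_1d)                 # interpreted as a diagonal (orthorhombic) cell matrix
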
43 changes: 38 additions & 5 deletions torchmdnet/calculators.py
@@ -4,6 +4,7 @@

import torch
from torchmdnet.models.model import load_model
import warnings

# dict of preset transforms
tranforms = {
@@ -42,6 +43,10 @@ class External:
Whether to use CUDA graphs to speed up the calculation. Default: False
cuda_graph_warmup_steps : int, optional
Number of steps to run as warmup before recording the CUDA graph. Default: 12
dtype : torch.dtype or str, optional
Cast the input to this dtype if defined. If passed as a string it should be a valid torch dtype. Default: torch.float32
kwargs : dict, optional
Extra arguments to pass to the model when loading it.
"""

def __init__(
@@ -52,10 +57,28 @@ def __init__(
output_transform=None,
use_cuda_graph=False,
cuda_graph_warmup_steps=12,
dtype=torch.float32,
**kwargs,
):
if isinstance(netfile, str):
self.model = load_model(netfile, device=device, derivative=True)
extra_args = kwargs
if use_cuda_graph:
warnings.warn(
"CUDA graphs are enabled, setting static_shapes=True and check_errors=False"
)
extra_args["static_shapes"] = True
extra_args["check_errors"] = False
self.model = load_model(
netfile,
device=device,
derivative=True,
**extra_args,
)
elif isinstance(netfile, torch.nn.Module):
if kwargs:
warnings.warn(
"Warning: extra arguments are being ignored when passing a torch.nn.Module"
)
self.model = netfile
else:
raise ValueError(
@@ -87,6 +110,12 @@ def __init__(
self.forces = None
self.box = None
self.pos = None
if isinstance(dtype, str):
try:
dtype = getattr(torch, dtype)
except AttributeError:
raise ValueError(f"Unknown torch dtype {dtype}")
self.dtype = dtype

def _init_cuda_graph(self):
stream = torch.cuda.Stream()
@@ -101,7 +130,7 @@ def _init_cuda_graph(self):
self.embeddings, self.pos, self.batch, self.box
)

def calculate(self, pos, box = None):
def calculate(self, pos, box=None):
"""Calculate the energy and forces of the system.
Parameters
@@ -118,7 +147,9 @@ def calculate(self, pos, box=None):
forces : torch.Tensor
Forces on the atoms in the system.
"""
pos = pos.to(self.device).type(torch.float32).reshape(-1, 3)
pos = pos.to(self.device).to(self.dtype).reshape(-1, 3)
if box is not None:
box = box.to(self.device).to(self.dtype)
if self.use_cuda_graph:
if self.pos is None:
self.pos = (
@@ -128,10 +159,12 @@
.requires_grad_(pos.requires_grad)
)
if self.box is None and box is not None:
self.box = box.clone().to(self.device).detach()
self.box = box.clone().to(self.device).to(self.dtype).detach()
if self.cuda_graph is None:
self._init_cuda_graph()
assert self.cuda_graph is not None, "CUDA graph is not initialized. This should not had happened."
assert (
self.cuda_graph is not None
), "CUDA graph is not initialized. This should not had happened."
with torch.no_grad():
self.pos.copy_(pos)
if box is not None:
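The dtype-as-string handling added above resolves the name via getattr on the torch module and re-raises unknown names as ValueError; a small illustration of the same lookup:

    import torch

    name = "float64"
    resolved = getattr(torch, name)   # torch.float64, as used by External(dtype="float64")
    # An invalid name such as "float128" raises AttributeError, which __init__ re-raises as ValueError.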
