Make TensorNet compatible with TorchScript (#186)
* Change some lines incompatible with jit script

* Remove some empty lines

* Fix typo

* Include an assert to appease TorchScript

* Change a range loop to an enumerate

* Add test for skewtensor function

* Small changes from merge

* Update test

* Update vector_to_skewtensor

* Remove some parentheses

* Small changes

* Remove skewtensor test

* Annotate types in Atomref

* Simplify a couple of operations

* Check also derivative in torchscript test

* Type annotate forward LLNP

* Try double backward in the TorchScript test

* Change test name

* Annotate forward

* Remove unused variables

* Remove unnecessary enumerates

* Add TorchScript GPU tests
RaulPPelaez committed Jun 27, 2023
1 parent e26dd40 commit a116847
Showing 4 changed files with 128 additions and 66 deletions.
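To put the diff in context: with this change TensorNet can be compiled with torch.jit.script like the other models, and gradients can still be driven through the scripted module. A minimal sketch of the intended usage, assuming the load_example_args and create_example_batch helpers from the repository's test suite (mirroring the updated test below):

import torch
from torchmdnet.models.model import create_model
from utils import load_example_args, create_example_batch  # test helpers, assumed importable from tests/

args = load_example_args("tensornet", remove_prior=True, derivative=True)
model = torch.jit.script(create_model(args))  # previously skipped for TensorNet
z, pos, batch = create_example_batch()
y, neg_dy = model(z, pos, batch=batch)  # energy and negative energy gradient
# Double backward through the scripted graph, as the new test also checks
grad_outputs = [torch.ones_like(neg_dy)]
ddy = torch.autograd.grad([neg_dy], [pos], grad_outputs=grad_outputs)[0]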
46 changes: 38 additions & 8 deletions tests/test_model.py
@@ -37,16 +37,47 @@ def test_forward_output_modules(model_name, output_model, dtype):


@mark.parametrize("model_name", models.__all__)
@mark.parametrize("dtype", [torch.float32, torch.float64])
def test_forward_torchscript(model_name, dtype):
if model_name == "tensornet":
pytest.skip("TensorNet does not support torchscript.")
@mark.parametrize("device", ["cpu", "cuda"])
def test_torchscript(model_name, device):
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA not available")
z, pos, batch = create_example_batch()
z = z.to(device)
pos = pos.to(device)
batch = batch.to(device)
model = torch.jit.script(
create_model(load_example_args(model_name, remove_prior=True, derivative=True, dtype=dtype))
)
model(z, pos, batch=batch)
create_model(load_example_args(model_name, remove_prior=True, derivative=True))
).to(device=device)
y, neg_dy = model(z, pos, batch=batch)
grad_outputs = [torch.ones_like(neg_dy)]
ddy = torch.autograd.grad(
[neg_dy],
[pos],
grad_outputs=grad_outputs,
)[0]

@mark.parametrize("model_name", models.__all__)
@mark.parametrize("device", ["cpu", "cuda"])
def test_torchscript_dynamic_shapes(model_name, device):
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA not available")
z, pos, batch = create_example_batch()
model = torch.jit.script(
create_model(load_example_args(model_name, remove_prior=True, derivative=True))
).to(device=device)
# Repeat the input to make it dynamic
for rep in range(0, 5):
print(rep)
zi = z.repeat_interleave(rep+1, dim=0).to(device=device)
posi = pos.repeat_interleave(rep+1, dim=0).to(device=device)
batchi = torch.randint(0, 10, (zi.shape[0],)).sort()[0].to(device=device)
y, neg_dy = model(zi, posi, batch=batchi)
grad_outputs = [torch.ones_like(neg_dy)]
ddy = torch.autograd.grad(
[neg_dy],
[posi],
grad_outputs=grad_outputs,
)[0]

@mark.parametrize("model_name", models.__all__)
def test_seed(model_name):
@@ -59,7 +90,6 @@ def test_seed(model_name):
for p1, p2 in zip(m1.parameters(), m2.parameters()):
assert (p1 == p2).all(), "Parameters don't match although using the same seed."


@mark.parametrize("model_name", models.__all__)
@mark.parametrize(
"output_model",
131 changes: 76 additions & 55 deletions torchmdnet/models/tensornet.py
@@ -13,11 +13,23 @@

# Creates a skew-symmetric tensor from a vector
def vector_to_skewtensor(vector):
tensor = torch.cross(
*torch.broadcast_tensors(
vector[..., None], torch.eye(3, 3, device=vector.device, dtype=vector.dtype)[None, None]
)
batch_size = vector.size(0)
zero = torch.zeros(batch_size, device=vector.device, dtype=vector.dtype)
tensor = torch.stack(
(
zero,
-vector[:, 2],
vector[:, 1],
vector[:, 2],
zero,
-vector[:, 0],
-vector[:, 1],
vector[:, 0],
zero,
),
dim=1,
)
tensor = tensor.view(-1, 3, 3)
return tensor.squeeze(0)
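As an illustrative sanity check (not part of the commit), the nine stacked entries form the standard cross-product matrix of the input vector, so the function keeps its documented meaning while avoiding the earlier broadcast-based torch.cross call (presumably one of the lines flagged as incompatible with jit script):

v = torch.randn(8, 3)
u = torch.randn(8, 3)
skew = vector_to_skewtensor(v)  # shape (8, 3, 3)
assert torch.allclose(skew, -skew.transpose(-1, -2))  # skew-symmetric
assert torch.allclose(
    torch.einsum("bij,bj->bi", skew, u), torch.cross(v, u, dim=1), atol=1e-6
)  # skew(v) @ u == v x u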


@@ -43,9 +55,9 @@ def decompose_tensor(tensor):

# Modifies tensor by multiplying invariant features to irreducible components
def new_radial_tensor(I, A, S, f_I, f_A, f_S):
I = (f_I)[..., None, None] * I
A = (f_A)[..., None, None] * A
S = (f_S)[..., None, None] * S
I = f_I[..., None, None] * I
A = f_A[..., None, None] * A
S = f_S[..., None, None] * S
return I, A, S


@@ -102,6 +114,7 @@ def __init__(
dtype=torch.float32,
):
super(TensorNet, self).__init__()

assert rbf_type in rbf_class_mapping, (
f'Unknown RBF type "{rbf_type}". '
f'Choose from {", ".join(rbf_class_mapping.keys())}.'
@@ -110,6 +123,7 @@ def __init__(
f'Unknown activation function "{activation}". '
f'Choose from {", ".join(act_class_mapping.keys())}.'
)

assert equivariance_invariance_group in ["O(3)", "SO(3)"], (
f'Unknown group "{equivariance_invariance_group}". '
f"Choose O(3) or SO(3)."
@@ -139,6 +153,7 @@ def __init__(
max_z,
dtype,
).jittable()

self.layers = nn.ModuleList()
if num_layers != 0:
for _ in range(num_layers):
@@ -160,23 +175,34 @@ def __init__(

def reset_parameters(self):
self.tensor_embedding.reset_parameters()
for i in range(self.num_layers):
self.layers[i].reset_parameters()
for layer in self.layers:
layer.reset_parameters()
self.linear.reset_parameters()
self.out_norm.reset_parameters()

def forward(
self, z, pos, batch, q: Optional[Tensor] = None, s: Optional[Tensor] = None
):
self,
z: Tensor,
pos: Tensor,
batch: Tensor,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:

# Obtain graph, with distances and relative position vectors
edge_index, edge_weight, edge_vec = self.distance(pos, batch)
# This assert convinces TorchScript that edge_vec is a Tensor and not an Optional[Tensor]
assert (
edge_vec is not None
), "Distance module did not return directional information"

# Expand distances with radial basis functions
edge_attr = self.distance_expansion(edge_weight)
# Embedding from edge-wise tensors to node-wise tensors
X = self.tensor_embedding(z, edge_index, edge_weight, edge_vec, edge_attr)
# Interaction layers
for i in range(self.num_layers):
X = self.layers[i](X, edge_index, edge_weight, edge_attr)
for layer in self.layers:
X = layer(X, edge_index, edge_weight, edge_attr)
I, A, S = decompose_tensor(X)
x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)
x = self.out_norm(x)
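The assert added to forward above is the usual TorchScript idiom for narrowing an Optional[Tensor] to a plain Tensor. A self-contained sketch of the pattern (illustrative only, not repository code):

from typing import Optional
import torch

@torch.jit.script
def vec_norm(v: Optional[torch.Tensor]) -> torch.Tensor:
    # Without the assert, the compiler rejects v.norm() because v might still be None
    assert v is not None, "expected directional information"
    return v.norm()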
@@ -208,15 +234,10 @@ def __init__(
self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels, dtype=dtype)
self.act = activation()
self.linears_tensor = nn.ModuleList()
self.linears_tensor.append(
nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
)
self.linears_tensor.append(
nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
)
self.linears_tensor.append(
nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
)
for _ in range(3):
self.linears_tensor.append(
nn.Linear(hidden_channels, hidden_channels, bias=False)
)
self.linears_scalar = nn.ModuleList()
self.linears_scalar.append(
nn.Linear(hidden_channels, 2 * hidden_channels, bias=True, dtype=dtype)
@@ -239,16 +260,26 @@ def reset_parameters(self):
linear.reset_parameters()
self.init_norm.reset_parameters()

def forward(self, z, edge_index, edge_weight, edge_vec, edge_attr):
def forward(
self,
z: Tensor,
edge_index: Tensor,
edge_weight: Tensor,
edge_vec: Tensor,
edge_attr: Tensor,
):

Z = self.emb(z)
C = self.cutoff(edge_weight)
W1 = (self.distance_proj1(edge_attr)) * C.view(-1, 1)
W2 = (self.distance_proj2(edge_attr)) * C.view(-1, 1)
W3 = (self.distance_proj3(edge_attr)) * C.view(-1, 1)
W1 = self.distance_proj1(edge_attr) * C.view(-1, 1)
W2 = self.distance_proj2(edge_attr) * C.view(-1, 1)
W3 = self.distance_proj3(edge_attr) * C.view(-1, 1)
mask = edge_index[0] != edge_index[1]
edge_vec[mask] = edge_vec[mask] / torch.norm(edge_vec[mask], dim=1).unsqueeze(1)
Iij, Aij, Sij = new_radial_tensor(
torch.eye(3, 3, device=edge_vec.device, dtype=edge_vec.dtype)[None, None, :, :],
torch.eye(3, 3, device=edge_vec.device, dtype=edge_vec.dtype)[
None, None, :, :
],
vector_to_skewtensor(edge_vec)[..., None, :, :],
vector_to_symtensor(edge_vec)[..., None, :, :],
W1,
@@ -262,11 +293,12 @@ def forward(self, z, edge_index, edge_weight, edge_vec, edge_attr):
I = self.linears_tensor[0](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
A = self.linears_tensor[1](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
S = self.linears_tensor[2](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
for j in range(len(self.linears_scalar)):
norm = self.act(self.linears_scalar[j](norm))
for linear_scalar in self.linears_scalar:
norm = self.act(linear_scalar(norm))
norm = norm.reshape(norm.shape[0], self.hidden_channels, 3)
I, A, S = new_radial_tensor(I, A, S, norm[..., 0], norm[..., 1], norm[..., 2])
X = I + A + S

return X

def message(self, Z_i, Z_j, I, A, S):
@@ -275,6 +307,7 @@ def message(self, Z_i, Z_j, I, A, S):
I = Zij[..., None, None] * I
A = Zij[..., None, None] * A
S = Zij[..., None, None] * S

return I, A, S

def aggregate(
Expand All @@ -284,10 +317,12 @@ def aggregate(
ptr: Optional[torch.Tensor],
dim_size: Optional[int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

I, A, S = features
I = scatter(I, index, dim=self.node_dim, dim_size=dim_size)
A = scatter(A, index, dim=self.node_dim, dim_size=dim_size)
S = scatter(S, index, dim=self.node_dim, dim_size=dim_size)

return I, A, S

def update(
Expand Down Expand Up @@ -321,24 +356,10 @@ def __init__(
nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True, dtype=dtype)
)
self.linears_tensor = nn.ModuleList()
self.linears_tensor.append(
nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
)
self.linears_tensor.append(
nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
)
self.linears_tensor.append(
nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
)
self.linears_tensor.append(
nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
)
self.linears_tensor.append(
nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
)
self.linears_tensor.append(
nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
)
for _ in range(6):
self.linears_tensor.append(
nn.Linear(hidden_channels, hidden_channels, bias=False)
)
self.act = activation()
self.equivariance_invariance_group = equivariance_invariance_group
self.reset_parameters()
@@ -350,9 +371,10 @@ def reset_parameters(self):
linear.reset_parameters()

def forward(self, X, edge_index, edge_weight, edge_attr):

C = self.cutoff(edge_weight)
for i in range(len(self.linears_scalar)):
edge_attr = self.act(self.linears_scalar[i](edge_attr))
for linear_scalar in self.linears_scalar:
edge_attr = self.act(linear_scalar(edge_attr))
edge_attr = (edge_attr * C.view(-1, 1)).reshape(
edge_attr.shape[0], self.hidden_channels, 3
)
@@ -374,19 +396,17 @@ def forward(self, X, edge_index, edge_weight, edge_attr):
if self.equivariance_invariance_group == "SO(3)":
B = torch.matmul(Y, msg)
I, A, S = decompose_tensor(2 * B)
norm = tensor_norm(I + A + S)
I = I / (norm + 1)[..., None, None]
A = A / (norm + 1)[..., None, None]
S = S / (norm + 1)[..., None, None]
normp1 = (tensor_norm(I + A + S) + 1)[..., None, None]
I, A, S = I / normp1, A / normp1, S / normp1
I = self.linears_tensor[3](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
dX = I + A + S
dX = dX + torch.matmul(dX, dX)
X = X + dX
X = X + dX + dX**2
return X

def message(self, I_j, A_j, S_j, edge_attr):

I, A, S = new_radial_tensor(
I_j, A_j, S_j, edge_attr[..., 0], edge_attr[..., 1], edge_attr[..., 2]
)
@@ -399,6 +419,7 @@ def aggregate(
ptr: Optional[torch.Tensor],
dim_size: Optional[int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

I, A, S = features
I = scatter(I, index, dim=self.node_dim, dim_size=dim_size)
A = scatter(A, index, dim=self.node_dim, dim_size=dim_size)
12 changes: 11 additions & 1 deletion torchmdnet/module.py
@@ -2,6 +2,8 @@
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.functional import mse_loss, l1_loss
from torch import Tensor
from typing import Optional, Dict, Tuple

from pytorch_lightning import LightningModule
from torchmdnet.models.model import create_model, load_model
@@ -55,7 +57,15 @@ def configure_optimizers(self):
}
return [optimizer], [lr_scheduler]

def forward(self, z, pos, batch=None, q=None, s=None, extra_args=None):
def forward(self,
z: Tensor,
pos: Tensor,
batch: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Tensor]]:

return self.model(z, pos, batch=batch, q=q, s=s, extra_args=extra_args)

def training_step(self, batch, batch_idx):
5 changes: 3 additions & 2 deletions torchmdnet/priors/atomref.py
@@ -1,6 +1,7 @@
from torchmdnet.priors.base import BasePrior
from typing import Optional, Dict
import torch
from torch import nn
from torch import nn, Tensor
from pytorch_lightning.utilities import rank_zero_warn


@@ -37,5 +38,5 @@ def reset_parameters(self):
def get_init_args(self):
return dict(max_z=self.initial_atomref.size(0))

def pre_reduce(self, x, z, pos, batch, extra_args):
def pre_reduce(self, x: Tensor, z: Tensor, pos: Tensor, batch: Tensor, extra_args: Optional[Dict[str, Tensor]]):
return x + self.atomref(z)
