1313
# Creates a skew-symmetric tensor from a vector
def vector_to_skewtensor(vector):
    """Build the skew-symmetric (cross-product) matrix of each 3-vector.

    Generalized from the original, which only accepted a 2-D ``(batch, 3)``
    input (it indexed ``vector[:, i]`` and read ``vector.size(0)``); this
    version accepts any number of leading batch dimensions.

    Args:
        vector: tensor of shape ``(..., 3)``.

    Returns:
        Tensor of shape ``(..., 3, 3)`` such that ``out @ u == cross(v, u)``
        for each vector ``v``. As in the original, a leading batch dimension
        of size 1 is removed by the trailing ``squeeze(0)``.
    """
    zero = torch.zeros_like(vector[..., 0])
    # Rows of the skew matrix: [[0, -z, y], [z, 0, -x], [-y, x, 0]]
    tensor = torch.stack(
        (
            zero,
            -vector[..., 2],
            vector[..., 1],
            vector[..., 2],
            zero,
            -vector[..., 0],
            -vector[..., 1],
            vector[..., 0],
            zero,
        ),
        dim=-1,
    )
    tensor = tensor.reshape(vector.shape[:-1] + (3, 3))
    return tensor.squeeze(0)
2234
2335
@@ -43,9 +55,9 @@ def decompose_tensor(tensor):
4355
# Modifies tensor by multiplying invariant features to irreducible components
def new_radial_tensor(I, A, S, f_I, f_A, f_S):
    """Scale each irreducible component by its invariant feature weights.

    The weights are broadcast over the trailing 3x3 matrix dimensions of
    their respective components; the three scaled components are returned
    in the same (I, A, S) order.
    """
    def _scale(weights, component):
        # Append two singleton dims so the weights broadcast over the 3x3 block.
        return weights[..., None, None] * component

    return _scale(f_I, I), _scale(f_A, A), _scale(f_S, S)
5062
5163
@@ -102,6 +114,7 @@ def __init__(
102114 dtype = torch .float32 ,
103115 ):
104116 super (TensorNet , self ).__init__ ()
117+
105118 assert rbf_type in rbf_class_mapping , (
106119 f'Unknown RBF type "{ rbf_type } ". '
107120 f'Choose from { ", " .join (rbf_class_mapping .keys ())} .'
@@ -110,6 +123,7 @@ def __init__(
110123 f'Unknown activation function "{ activation } ". '
111124 f'Choose from { ", " .join (act_class_mapping .keys ())} .'
112125 )
126+
113127 assert equivariance_invariance_group in ["O(3)" , "SO(3)" ], (
114128 f'Unknown group "{ equivariance_invariance_group } ". '
115129 f"Choose O(3) or SO(3)."
@@ -139,6 +153,7 @@ def __init__(
139153 max_z ,
140154 dtype ,
141155 ).jittable ()
156+
142157 self .layers = nn .ModuleList ()
143158 if num_layers != 0 :
144159 for _ in range (num_layers ):
@@ -160,23 +175,34 @@ def __init__(
160175
161176 def reset_parameters (self ):
162177 self .tensor_embedding .reset_parameters ()
163- for i in range ( self .num_layers ) :
164- self . layers [ i ] .reset_parameters ()
178+ for layer in self .layers :
179+ layer .reset_parameters ()
165180 self .linear .reset_parameters ()
166181 self .out_norm .reset_parameters ()
167182
168183 def forward (
169- self , z , pos , batch , q : Optional [Tensor ] = None , s : Optional [Tensor ] = None
170- ):
184+ self ,
185+ z : Tensor ,
186+ pos : Tensor ,
187+ batch : Tensor ,
188+ q : Optional [Tensor ] = None ,
189+ s : Optional [Tensor ] = None ,
190+ ) -> Tuple [Tensor , Optional [Tensor ], Tensor , Tensor , Tensor ]:
191+
171192 # Obtain graph, with distances and relative position vectors
172193 edge_index , edge_weight , edge_vec = self .distance (pos , batch )
194+ # This assert convinces TorchScript that edge_vec is a Tensor and not an Optional[Tensor]
195+ assert (
196+ edge_vec is not None
197+ ), "Distance module did not return directional information"
198+
173199 # Expand distances with radial basis functions
174200 edge_attr = self .distance_expansion (edge_weight )
175201 # Embedding from edge-wise tensors to node-wise tensors
176202 X = self .tensor_embedding (z , edge_index , edge_weight , edge_vec , edge_attr )
177203 # Interaction layers
178- for i in range ( self .num_layers ) :
179- X = self . layers [ i ] (X , edge_index , edge_weight , edge_attr )
204+ for layer in self .layers :
205+ X = layer (X , edge_index , edge_weight , edge_attr )
180206 I , A , S = decompose_tensor (X )
181207 x = torch .cat ((tensor_norm (I ), tensor_norm (A ), tensor_norm (S )), dim = - 1 )
182208 x = self .out_norm (x )
@@ -208,15 +234,10 @@ def __init__(
208234 self .emb2 = nn .Linear (2 * hidden_channels , hidden_channels , dtype = dtype )
209235 self .act = activation ()
210236 self .linears_tensor = nn .ModuleList ()
211- self .linears_tensor .append (
212- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
213- )
214- self .linears_tensor .append (
215- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
216- )
217- self .linears_tensor .append (
218- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
219- )
237+ for _ in range (3 ):
238+ self .linears_tensor .append (
239+ nn .Linear (hidden_channels , hidden_channels , bias = False )
240+ )
220241 self .linears_scalar = nn .ModuleList ()
221242 self .linears_scalar .append (
222243 nn .Linear (hidden_channels , 2 * hidden_channels , bias = True , dtype = dtype )
@@ -239,16 +260,26 @@ def reset_parameters(self):
239260 linear .reset_parameters ()
240261 self .init_norm .reset_parameters ()
241262
242- def forward (self , z , edge_index , edge_weight , edge_vec , edge_attr ):
263+ def forward (
264+ self ,
265+ z : Tensor ,
266+ edge_index : Tensor ,
267+ edge_weight : Tensor ,
268+ edge_vec : Tensor ,
269+ edge_attr : Tensor ,
270+ ):
271+
243272 Z = self .emb (z )
244273 C = self .cutoff (edge_weight )
245- W1 = ( self .distance_proj1 (edge_attr ) ) * C .view (- 1 , 1 )
246- W2 = ( self .distance_proj2 (edge_attr ) ) * C .view (- 1 , 1 )
247- W3 = ( self .distance_proj3 (edge_attr ) ) * C .view (- 1 , 1 )
274+ W1 = self .distance_proj1 (edge_attr ) * C .view (- 1 , 1 )
275+ W2 = self .distance_proj2 (edge_attr ) * C .view (- 1 , 1 )
276+ W3 = self .distance_proj3 (edge_attr ) * C .view (- 1 , 1 )
248277 mask = edge_index [0 ] != edge_index [1 ]
249278 edge_vec [mask ] = edge_vec [mask ] / torch .norm (edge_vec [mask ], dim = 1 ).unsqueeze (1 )
250279 Iij , Aij , Sij = new_radial_tensor (
251- torch .eye (3 , 3 , device = edge_vec .device , dtype = edge_vec .dtype )[None , None , :, :],
280+ torch .eye (3 , 3 , device = edge_vec .device , dtype = edge_vec .dtype )[
281+ None , None , :, :
282+ ],
252283 vector_to_skewtensor (edge_vec )[..., None , :, :],
253284 vector_to_symtensor (edge_vec )[..., None , :, :],
254285 W1 ,
@@ -262,11 +293,12 @@ def forward(self, z, edge_index, edge_weight, edge_vec, edge_attr):
262293 I = self .linears_tensor [0 ](I .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
263294 A = self .linears_tensor [1 ](A .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
264295 S = self .linears_tensor [2 ](S .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
265- for j in range ( len ( self .linears_scalar )) :
266- norm = self .act (self . linears_scalar [ j ] (norm ))
296+ for linear_scalar in self .linears_scalar :
297+ norm = self .act (linear_scalar (norm ))
267298 norm = norm .reshape (norm .shape [0 ], self .hidden_channels , 3 )
268299 I , A , S = new_radial_tensor (I , A , S , norm [..., 0 ], norm [..., 1 ], norm [..., 2 ])
269300 X = I + A + S
301+
270302 return X
271303
272304 def message (self , Z_i , Z_j , I , A , S ):
@@ -275,6 +307,7 @@ def message(self, Z_i, Z_j, I, A, S):
275307 I = Zij [..., None , None ] * I
276308 A = Zij [..., None , None ] * A
277309 S = Zij [..., None , None ] * S
310+
278311 return I , A , S
279312
280313 def aggregate (
@@ -284,10 +317,12 @@ def aggregate(
284317 ptr : Optional [torch .Tensor ],
285318 dim_size : Optional [int ],
286319 ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
320+
287321 I , A , S = features
288322 I = scatter (I , index , dim = self .node_dim , dim_size = dim_size )
289323 A = scatter (A , index , dim = self .node_dim , dim_size = dim_size )
290324 S = scatter (S , index , dim = self .node_dim , dim_size = dim_size )
325+
291326 return I , A , S
292327
293328 def update (
@@ -321,24 +356,10 @@ def __init__(
321356 nn .Linear (2 * hidden_channels , 3 * hidden_channels , bias = True , dtype = dtype )
322357 )
323358 self .linears_tensor = nn .ModuleList ()
324- self .linears_tensor .append (
325- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
326- )
327- self .linears_tensor .append (
328- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
329- )
330- self .linears_tensor .append (
331- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
332- )
333- self .linears_tensor .append (
334- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
335- )
336- self .linears_tensor .append (
337- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
338- )
339- self .linears_tensor .append (
340- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
341- )
359+ for _ in range (6 ):
360+ self .linears_tensor .append (
361+ nn .Linear (hidden_channels , hidden_channels , bias = False )
362+ )
342363 self .act = activation ()
343364 self .equivariance_invariance_group = equivariance_invariance_group
344365 self .reset_parameters ()
@@ -350,9 +371,10 @@ def reset_parameters(self):
350371 linear .reset_parameters ()
351372
352373 def forward (self , X , edge_index , edge_weight , edge_attr ):
374+
353375 C = self .cutoff (edge_weight )
354- for i in range ( len ( self .linears_scalar )) :
355- edge_attr = self .act (self . linears_scalar [ i ] (edge_attr ))
376+ for linear_scalar in self .linears_scalar :
377+ edge_attr = self .act (linear_scalar (edge_attr ))
356378 edge_attr = (edge_attr * C .view (- 1 , 1 )).reshape (
357379 edge_attr .shape [0 ], self .hidden_channels , 3
358380 )
@@ -374,19 +396,17 @@ def forward(self, X, edge_index, edge_weight, edge_attr):
374396 if self .equivariance_invariance_group == "SO(3)" :
375397 B = torch .matmul (Y , msg )
376398 I , A , S = decompose_tensor (2 * B )
377- norm = tensor_norm (I + A + S )
378- I = I / (norm + 1 )[..., None , None ]
379- A = A / (norm + 1 )[..., None , None ]
380- S = S / (norm + 1 )[..., None , None ]
399+ normp1 = (tensor_norm (I + A + S ) + 1 )[..., None , None ]
400+ I , A , S = I / normp1 , A / normp1 , S / normp1
381401 I = self .linears_tensor [3 ](I .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
382402 A = self .linears_tensor [4 ](A .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
383403 S = self .linears_tensor [5 ](S .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
384404 dX = I + A + S
385- dX = dX + torch .matmul (dX , dX )
386- X = X + dX
405+ X = X + dX + dX ** 2
387406 return X
388407
389408 def message (self , I_j , A_j , S_j , edge_attr ):
409+
390410 I , A , S = new_radial_tensor (
391411 I_j , A_j , S_j , edge_attr [..., 0 ], edge_attr [..., 1 ], edge_attr [..., 2 ]
392412 )
@@ -399,6 +419,7 @@ def aggregate(
399419 ptr : Optional [torch .Tensor ],
400420 dim_size : Optional [int ],
401421 ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
422+
402423 I , A , S = features
403424 I = scatter (I , index , dim = self .node_dim , dim_size = dim_size )
404425 A = scatter (A , index , dim = self .node_dim , dim_size = dim_size )
0 commit comments