rework template generation to be more general and powerful. Fix HDNNP2nd.
PatReis committed Dec 7, 2023
1 parent 227c089 commit 933d1c4
Showing 25 changed files with 563 additions and 247 deletions.
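The recurring change in the model builders below is that `template_cast_list_input` no longer takes coarse `has_nodes`/`has_edges`/`has_angle_indices` flags but one `mask_assignment` and one `index_assignment` entry per model input. A minimal sketch of the new call, adapted from the GCN diff in this commit; the import paths, input shapes and `input_tensor_type` value are illustrative assumptions, and the meaning of the two lists is read off the call sites rather than taken from documentation:

    # Sketch only: import locations and input definitions are assumptions for illustration.
    from keras.layers import Input
    from kgcnn.models.casting import template_cast_list_input  # assumed module path

    model_inputs = [
        Input(shape=(None,), name="node_number", dtype="int64"),
        Input(shape=(None, 10), name="edge_attributes", dtype="float32"),
        Input(shape=(None, 2), name="edge_indices", dtype="int64"),
    ]
    dj_inputs = template_cast_list_input(
        model_inputs,
        input_tensor_type="padded",  # assumed value
        cast_disjoint_kwargs={},
        # One entry per input: the mask/count group the tensor belongs to
        # (here 0 = nodes, 1 = edges), or None for tensors that need no mask.
        mask_assignment=[0, 1, 1],
        # One entry per input: for index tensors, the position of the input
        # their indices point into (here 0 = the node tensor); None otherwise.
        index_assignment=[None, None, 0],
    )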
9 changes: 6 additions & 3 deletions kgcnn/layers/mlp.py
@@ -395,7 +395,7 @@ def build(self, input_shape):
                 input_shape[0], input_shape[1], [])
         for i in range(self._depth):
             self.mlp_dense_layer_list[i].build([x_shape, r_shape])
-            x_shape = self.mlp_dense_layer_list[i].compute_output_shape(x_shape)
+            x_shape = self.mlp_dense_layer_list[i].compute_output_shape([x_shape, r_shape])
             if self._conf_use_dropout[i]:
                 self.mlp_dropout_layer_list[i].build(x_shape)
             if self._conf_use_normalization[i]:
@@ -416,13 +416,16 @@ def call(self, inputs, **kwargs):
         Returns:
             Tensor: MLP forward pass.
         """
-        x, relations = inputs
+        x, relations, batch = (inputs[0], inputs[1], inputs[2:]) if len(inputs) > 2 else (inputs[0], inputs[1], [])
         for i in range(self._depth):
             x = self.mlp_dense_layer_list[i]([x, relations], **kwargs)
             if self._conf_use_dropout[i]:
                 x = self.mlp_dropout_layer_list[i](x, **kwargs)
             if self._conf_use_normalization[i]:
-                x = self.mlp_norm_layer_list[i](x, **kwargs)
+                if self.is_graph_norm_layer[i]:
+                    x = self.mlp_norm_layer_list[i]([x]+batch, **kwargs)
+                else:
+                    x = self.mlp_norm_layer_list[i](x, **kwargs)
             x = self.mlp_activation_layer_list[i](x, **kwargs)
         out = x
         return out
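The `call` in the hunk above now tolerates an optional batch suffix in its input list and forwards it only to normalization layers flagged as graph normalization (`is_graph_norm_layer`). A minimal sketch of the two resulting call forms; `mlp_kwargs` and the tensor names are placeholders, only the list layouts follow the diff:

    # Sketch only: mlp_kwargs and the tensor names stand for whatever the model config supplies.
    mlp = RelationalMLP(**mlp_kwargs)
    # Plain call, as before (sufficient when no graph normalization layer is configured):
    n = mlp([features, node_number])
    # Extended call: batch ids and per-graph node counts are appended so that graph
    # normalization layers inside the MLP can normalize per graph:
    n = mlp([features, node_number, batch_id_node, count_nodes])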
7 changes: 6 additions & 1 deletion kgcnn/literature/AttentiveFP/_make.py
@@ -114,7 +114,12 @@ def make_model(inputs: list = None,
     model_inputs = [Input(**x) for x in inputs]
 
     di_inputs = template_cast_list_input(
-        model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs)
+        model_inputs,
+        input_tensor_type=input_tensor_type,
+        cast_disjoint_kwargs=cast_disjoint_kwargs,
+        mask_assignment=[0, 0, 1],
+        index_assignment=[None, None, 0]
+    )
 
     n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = di_inputs
 
5 changes: 2 additions & 3 deletions kgcnn/literature/CGCNN/_make.py
@@ -138,9 +138,8 @@ def make_crystal_model(inputs: list = None,
         model_inputs,
         input_tensor_type=input_tensor_type,
         cast_disjoint_kwargs=cast_disjoint_kwargs,
-        has_edges=int(not make_distances),
-        has_nodes=1 + int(make_distances),
-        has_crystal_input=2+2*int(representation == "asu")
+        mask_assignment=[0, 0 if make_distances else 1, 1, 1, None] + ([0, 1] if representation == "asu" else []),
+        index_assignment=[None, None, 0, None, None] + ([None, None] if representation == "asu" else [])
     )
 
     if representation == "asu":
5 changes: 2 additions & 3 deletions kgcnn/literature/CMPNN/_make.py
@@ -137,9 +137,8 @@ def make_model(name: str = None,
         model_inputs,
         input_tensor_type=input_tensor_type,
         cast_disjoint_kwargs=cast_disjoint_kwargs,
-        has_nodes=True, has_edges=True,
-        has_angle_indices=True, # Treat reverse indices as edge indices
-        has_edge_indices=True
+        mask_assignment=[0, 1, 1, 2],
+        index_assignment=[None, None, 0, 2]
     )
 
     n, ed, edi, e_pairs, batch_id_node, batch_id_edge, _, node_id, edge_id, _, count_nodes, count_edges, _ = di
6 changes: 2 additions & 4 deletions kgcnn/literature/DGIN/_make.py
@@ -151,10 +151,8 @@ def make_model(name: str = None,
     di = template_cast_list_input(
         model_inputs, input_tensor_type=input_tensor_type,
         cast_disjoint_kwargs=cast_disjoint_kwargs,
-        has_nodes=True, has_edges=True,
-        has_graph_state=use_graph_state,
-        has_angle_indices=True, # Treat reverse indices as edge indices
-        has_edge_indices=True
+        index_assignment=[None, None, 0, 2] + ([None] if use_graph_state else []),
+        mask_assignment=[0, 1, 1, 2] + ([None] if use_graph_state else [])
     )
 
     if use_graph_state:
5 changes: 2 additions & 3 deletions kgcnn/literature/DMPNN/_make.py
@@ -132,9 +132,8 @@ def make_model(name: str = None,
 
     di = template_cast_list_input(
         model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs,
-        has_nodes=True, has_edges=True, has_graph_state=use_graph_state,
-        has_angle_indices=True, # Treat reverse indices as edge indices
-        has_edge_indices=True
+        mask_assignment=[0,1,1,2] + ([None] if use_graph_state else []),
+        index_assignment=[None, None, 0, 2] + ([None] if use_graph_state else [])
     )
 
     if use_graph_state:
11 changes: 4 additions & 7 deletions kgcnn/literature/DimeNetPP/_make.py
@@ -144,9 +144,8 @@ def make_model(inputs: list = None,
         model_inputs,
         input_tensor_type=input_tensor_type,
         cast_disjoint_kwargs=cast_disjoint_kwargs,
-        has_edges=False,
-        has_nodes=2,
-        has_angle_indices=True,
+        mask_assignment=[0, 0, 1, 2],
+        index_assignment=[None, None, 0, 2]
     )
 
     n, x, edi, adi, batch_id_node, batch_id_edge, batch_id_angles, node_id, edge_id, angle_id, count_nodes, count_edges, count_angles = dj
@@ -333,10 +332,8 @@ def make_crystal_model(inputs: list = None,
     disjoint_inputs = template_cast_list_input(
         model_inputs, input_tensor_type=input_tensor_type,
         cast_disjoint_kwargs=cast_disjoint_kwargs,
-        has_edges=False,
-        has_nodes=2,
-        has_angle_indices=True,
-        has_crystal_input=2
+        index_assignment=[None, None, 0, 2, None, None],
+        mask_assignment=[0, 0, 1, 2, 1, None]
     )
     n, x, edi, angi, img, lattice, batch_id_node, batch_id_edge, batch_id_angles, node_id, edge_id, angle_id, count_nodes, count_edges, count_angles = disjoint_inputs
 
4 changes: 2 additions & 2 deletions kgcnn/literature/EGNN/_make.py
@@ -152,8 +152,8 @@ def make_model(name: str = None,
         model_inputs,
         input_tensor_type=input_tensor_type,
         cast_disjoint_kwargs=cast_disjoint_kwargs,
-        has_edges=True,
-        has_nodes=2
+        index_assignment=[None, None, None, 0],
+        mask_assignment=[0, 0, 1, 1]
     )
 
     n, x, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj
7 changes: 6 additions & 1 deletion kgcnn/literature/GAT/_make.py
@@ -114,7 +114,12 @@ def make_model(inputs: list = None,
     model_inputs = [Input(**x) for x in inputs]
 
     di_inputs = template_cast_list_input(
-        model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs)
+        model_inputs,
+        input_tensor_type=input_tensor_type,
+        cast_disjoint_kwargs=cast_disjoint_kwargs,
+        mask_assignment=[0, 1, 1],
+        index_assignment=[None, None, 0]
+    )
 
     n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = di_inputs
 
7 changes: 6 additions & 1 deletion kgcnn/literature/GATv2/_make.py
@@ -116,7 +116,12 @@ def make_model(inputs: list = None,
     model_inputs = [Input(**x) for x in inputs]
 
     dj_model_inputs = template_cast_list_input(
-        model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs)
+        model_inputs,
+        input_tensor_type=input_tensor_type,
+        cast_disjoint_kwargs=cast_disjoint_kwargs,
+        mask_assignment=[0, 1, 1],
+        index_assignment=[None, None, 0]
+    )
 
     n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_model_inputs
 
7 changes: 6 additions & 1 deletion kgcnn/literature/GCN/_make.py
@@ -111,7 +111,12 @@ def make_model(inputs: list = None,
     model_inputs = [Input(**x) for x in inputs]
 
     dj_inputs = template_cast_list_input(
-        model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs)
+        model_inputs,
+        input_tensor_type=input_tensor_type,
+        cast_disjoint_kwargs=cast_disjoint_kwargs,
+        mask_assignment=[0, 1, 1],
+        index_assignment=[None, None, 0]
+    )
 
     n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_inputs
 
3 changes: 2 additions & 1 deletion kgcnn/literature/GIN/_make.py
@@ -110,7 +110,8 @@ def make_model(inputs: list = None,
 
     disjoint_inputs = template_cast_list_input(
         model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs,
-        has_edges=False
+        mask_assignment=[0, 1],
+        index_assignment=[None, 0]
     )
 
     n, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs
4 changes: 3 additions & 1 deletion kgcnn/literature/GNNFilm/_make.py
@@ -114,7 +114,9 @@ def make_model(inputs: list = None,
     dj_inputs = template_cast_list_input(
         model_inputs,
         input_tensor_type=input_tensor_type,
-        cast_disjoint_kwargs=cast_disjoint_kwargs
+        cast_disjoint_kwargs=cast_disjoint_kwargs,
+        mask_assignment=[0, 1, 1],
+        index_assignment=[None, None, 0]
     )
 
     n, er, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_inputs
7 changes: 6 additions & 1 deletion kgcnn/literature/GraphSAGE/_make.py
@@ -120,7 +120,12 @@ def make_model(inputs: list = None,
     model_inputs = [Input(**x) for x in inputs]
 
     dj_inputs = template_cast_list_input(
-        model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs)
+        model_inputs,
+        input_tensor_type=input_tensor_type,
+        cast_disjoint_kwargs=cast_disjoint_kwargs,
+        mask_assignment=[0, 1, 1],
+        index_assignment=[None, None, 0]
+    )
 
     n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_inputs
 
23 changes: 10 additions & 13 deletions kgcnn/literature/HDNNP2nd/_make.py
@@ -118,12 +118,11 @@ def make_model_weighted(inputs: list = None,
         model_inputs,
         input_tensor_type=input_tensor_type,
         cast_disjoint_kwargs=cast_disjoint_kwargs,
-        has_edges=False,
-        has_nodes=2,
-        has_angle_indices=True
+        mask_assignment=[0, 0, 1, 2],
+        index_assignment=[None, None, 0, 0]
     )
 
-    n, x, disjoint_indices, ang_ind, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj
+    n, x, disjoint_indices, ang_ind, batch_id_node, batch_id_edge, batch_id_angles, node_id, edge_id, angle_id, count_nodes, count_edges, count_angle = dj
 
     out = model_disjoint_weighted(
         [n, x, disjoint_indices, ang_ind, batch_id_node, count_nodes],
@@ -267,12 +266,11 @@ def make_model_behler(inputs: list = None,
         model_inputs,
         input_tensor_type=input_tensor_type,
         cast_disjoint_kwargs=cast_disjoint_kwargs,
-        has_edges=False,
-        has_nodes=2,
-        has_angle_indices=True
+        mask_assignment=[0, 0, 1, 2],
+        index_assignment=[None, None, 0, 0]
     )
 
-    n, x, disjoint_indices, ang_index, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj
+    n, x, disjoint_indices, ang_index, batch_id_node, batch_id_edge, batch_id_angles, node_id, edge_id, angle_id, count_nodes, count_edges, count_angle = dj
 
     out = model_disjoint_behler(
         [n, x, disjoint_indices, ang_index, batch_id_node, count_nodes],
@@ -400,13 +398,12 @@ def make_model_atom_wise(inputs: list = None,
         model_inputs,
         input_tensor_type=input_tensor_type,
         cast_disjoint_kwargs=cast_disjoint_kwargs,
-        has_edges=False,
-        has_nodes=2,
-        has_angle_indices=False,
-        has_edge_indices=False
+        mask_assignment=[0, 0],
+        index_assignment=[None, None]
     )
 
-    n, x, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj
+    n, x, batch_id_node, node_id, count_nodes = dj
+    batch_id_edge, edge_id, count_edges = None, None, None
 
     out = model_disjoint_atom_wise(
         [n, x, batch_id_node, count_nodes],
30 changes: 15 additions & 15 deletions kgcnn/literature/HDNNP2nd/_model.py
@@ -3,7 +3,7 @@
 from kgcnn.layers.pooling import PoolingNodes
 from kgcnn.layers.norm import GraphBatchNormalization
 from ._wacsf import wACSFRad, wACSFAng
-from ._acsf import ACSFG2, ACSFG4
+from ._acsf import ACSFG2, ACSFG4, ACSFConstNormalization
 
 
 def model_disjoint_weighted(
@@ -19,7 +19,7 @@ def model_disjoint_weighted(
         output_mlp: dict = None
 ):
     # Make input
-    node_input, xyz_input, edge_index_input, angle_index_input = inputs
+    node_input, xyz_input, edge_index_input, angle_index_input, batch_id_node, count_nodes = inputs
 
     # ACSF representation.
     rep_rad = wACSFRad(**w_acsf_rad_kwargs)([node_input, xyz_input, edge_index_input])
@@ -28,22 +28,22 @@ def model_disjoint_weighted(
 
     # Normalization
     if normalize_kwargs:
-        rep = GraphBatchNormalization(**normalize_kwargs)(rep)
+        rep = GraphBatchNormalization(**normalize_kwargs)([rep, batch_id_node, count_nodes])
     if const_normalize_kwargs:
         rep = ACSFConstNormalization(**const_normalize_kwargs)(rep)
 
     # learnable NN.
-    n = RelationalMLP(**mlp_kwargs)([rep, node_input])
+    n = RelationalMLP(**mlp_kwargs)([rep, node_input, batch_id_node, count_nodes])
 
     # Output embedding choice
     if output_embedding == 'graph':
-        out = PoolingNodes(**node_pooling_args)(n)
+        out = PoolingNodes(**node_pooling_args)([count_nodes, n, batch_id_node])
         if use_output_mlp:
             out = MLP(**output_mlp)(out)
     elif output_embedding == 'node':
         out = n
         if use_output_mlp:
-            out = GraphMLP(**output_mlp)(out)
+            out = GraphMLP(**output_mlp)([out, batch_id_node, count_nodes])
     else:
         raise ValueError("Unsupported output embedding for mode `HDNNP2nd` .")
 
@@ -63,7 +63,7 @@ def model_disjoint_behler(
         output_mlp: dict = None
 ):
     # Make input
-    node_input, xyz_input, edge_index_input, angle_index_input = inputs
+    node_input, xyz_input, edge_index_input, angle_index_input, batch_id_node, count_nodes = inputs
 
     # ACSF representation.
     rep_g2 = ACSFG2(**ACSFG2.make_param_table(**g2_kwargs))([node_input, xyz_input, edge_index_input])
@@ -72,22 +72,22 @@ def model_disjoint_behler(
 
     # Normalization
     if normalize_kwargs:
-        rep = GraphBatchNormalization(**normalize_kwargs)(rep)
+        rep = GraphBatchNormalization(**normalize_kwargs)([rep, batch_id_node, count_nodes])
     if const_normalize_kwargs:
         rep = ACSFConstNormalization(**const_normalize_kwargs)(rep)
 
     # learnable NN.
-    n = RelationalMLP(**mlp_kwargs)([rep, node_input])
+    n = RelationalMLP(**mlp_kwargs)([rep, node_input, batch_id_node, count_nodes])
 
     # Output embedding choice
     if output_embedding == 'graph':
-        out = PoolingNodes(**node_pooling_args)(n)
+        out = PoolingNodes(**node_pooling_args)([count_nodes, n, batch_id_node])
         if use_output_mlp:
             out = MLP(**output_mlp)(out)
     elif output_embedding == 'node':
         out = n
         if use_output_mlp:
-            out = GraphMLP(**output_mlp)(out)
+            out = GraphMLP(**output_mlp)([out, batch_id_node, count_nodes])
     else:
         raise ValueError("Unsupported output embedding for mode `HDNNP2nd`")
 
@@ -103,20 +103,20 @@ def model_disjoint_atom_wise(
         output_mlp: dict = None
 ):
     # Make input
-    node_input, rep_input = inputs
+    node_input, rep_input, batch_id_node, count_nodes = inputs
 
     # learnable NN.
-    n = RelationalMLP(**mlp_kwargs)([rep_input, node_input])
+    n = RelationalMLP(**mlp_kwargs)([rep_input, node_input, batch_id_node, count_nodes])
 
     # Output embedding choice
     if output_embedding == 'graph':
-        out = PoolingNodes(**node_pooling_args)(n)
+        out = PoolingNodes(**node_pooling_args)([count_nodes, n, batch_id_node])
         if use_output_mlp:
             out = MLP(**output_mlp)(out)
     elif output_embedding == 'node':
         out = n
         if use_output_mlp:
-            out = GraphMLP(**output_mlp)(out)
+            out = GraphMLP(**output_mlp)([out, batch_id_node, count_nodes])
     else:
         raise ValueError("Unsupported output embedding for mode `HDNNP2nd`")
 
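Across these disjoint model bodies the batch bookkeeping is now passed explicitly: `GraphBatchNormalization`, `RelationalMLP` and `GraphMLP` get `batch_id_node`/`count_nodes` appended to their input lists, while `PoolingNodes` takes counts first, then node features, then batch ids. A short sketch of the graph-level readout with that ordering; the layer kwargs are placeholders:

    # Sketch only: node_pooling_args / output_mlp stand for the configured kwargs.
    out = PoolingNodes(**node_pooling_args)([count_nodes, n, batch_id_node])  # per-graph readout
    out = MLP(**output_mlp)(out)                                              # graph-level output head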
7 changes: 3 additions & 4 deletions kgcnn/literature/HamNet/_make.py
@@ -28,7 +28,7 @@
     "inputs": [
         {'shape': (None,), 'name': "node_number", 'dtype': 'int64'},
         {'shape': (None, 3), 'name': "node_coordinates", 'dtype': 'float32'},
-        {'shape': (None, 32), 'name': "edge_attributes", 'dtype': 'float32'},
+        {'shape': (None, 64), 'name': "edge_attributes", 'dtype': 'float32'},
         {'shape': (None, 2), 'name': "edge_indices", 'dtype': 'int64'},
         {"shape": (), "name": "total_nodes", "dtype": "int64"},
         {"shape": (), "name": "total_edges", "dtype": "int64"}
@@ -130,9 +130,8 @@ def make_model(name: str = None,
         model_inputs,
         input_tensor_type=input_tensor_type,
         cast_disjoint_kwargs=cast_disjoint_kwargs,
-        has_nodes=2,
-        has_edges=True,
-        has_edge_indices=True
+        mask_assignment=[0, 0, 1, 1],
+        index_assignment=[None, None, None, 0]
     )
 
     n, x, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = di_inputs
