From 933d1c4ae94fdd78e7c84199c4a7885cff9a879b Mon Sep 17 00:00:00 2001 From: PatReis Date: Thu, 7 Dec 2023 18:01:15 +0100 Subject: [PATCH] rework template generation to be more general and powerful. Fix HDNNP2nd. --- kgcnn/layers/mlp.py | 9 +- kgcnn/literature/AttentiveFP/_make.py | 7 +- kgcnn/literature/CGCNN/_make.py | 5 +- kgcnn/literature/CMPNN/_make.py | 5 +- kgcnn/literature/DGIN/_make.py | 6 +- kgcnn/literature/DMPNN/_make.py | 5 +- kgcnn/literature/DimeNetPP/_make.py | 11 +- kgcnn/literature/EGNN/_make.py | 4 +- kgcnn/literature/GAT/_make.py | 7 +- kgcnn/literature/GATv2/_make.py | 7 +- kgcnn/literature/GCN/_make.py | 7 +- kgcnn/literature/GIN/_make.py | 3 +- kgcnn/literature/GNNFilm/_make.py | 4 +- kgcnn/literature/GraphSAGE/_make.py | 7 +- kgcnn/literature/HDNNP2nd/_make.py | 23 +-- kgcnn/literature/HDNNP2nd/_model.py | 30 +-- kgcnn/literature/HamNet/_make.py | 7 +- kgcnn/literature/Megnet/_make.py | 12 +- kgcnn/literature/NMPN/_make.py | 179 +++++++++++++++- kgcnn/literature/NMPN/_model.py | 81 +++++++- kgcnn/literature/PAiNN/_make.py | 17 +- kgcnn/literature/RGCN/_make.py | 5 +- kgcnn/literature/Schnet/_make.py | 19 +- kgcnn/models/casting.py | 280 ++++++++++++-------------- training/hyper/hyper_esol.py | 70 +++++++ 25 files changed, 563 insertions(+), 247 deletions(-) diff --git a/kgcnn/layers/mlp.py b/kgcnn/layers/mlp.py index 38269244..b425608a 100644 --- a/kgcnn/layers/mlp.py +++ b/kgcnn/layers/mlp.py @@ -395,7 +395,7 @@ def build(self, input_shape): input_shape[0], input_shape[1], []) for i in range(self._depth): self.mlp_dense_layer_list[i].build([x_shape, r_shape]) - x_shape = self.mlp_dense_layer_list[i].compute_output_shape(x_shape) + x_shape = self.mlp_dense_layer_list[i].compute_output_shape([x_shape, r_shape]) if self._conf_use_dropout[i]: self.mlp_dropout_layer_list[i].build(x_shape) if self._conf_use_normalization[i]: @@ -416,13 +416,16 @@ def call(self, inputs, **kwargs): Returns: Tensor: MLP forward pass. 
""" - x, relations = inputs + x, relations, batch = (inputs[0], inputs[1], inputs[2:]) if len(inputs) > 2 else (inputs[0], inputs[1], []) for i in range(self._depth): x = self.mlp_dense_layer_list[i]([x, relations], **kwargs) if self._conf_use_dropout[i]: x = self.mlp_dropout_layer_list[i](x, **kwargs) if self._conf_use_normalization[i]: - x = self.mlp_norm_layer_list[i](x, **kwargs) + if self.is_graph_norm_layer[i]: + x = self.mlp_norm_layer_list[i]([x]+batch, **kwargs) + else: + x = self.mlp_norm_layer_list[i](x, **kwargs) x = self.mlp_activation_layer_list[i](x, **kwargs) out = x return out diff --git a/kgcnn/literature/AttentiveFP/_make.py b/kgcnn/literature/AttentiveFP/_make.py index d1916e40..7d91c842 100644 --- a/kgcnn/literature/AttentiveFP/_make.py +++ b/kgcnn/literature/AttentiveFP/_make.py @@ -114,7 +114,12 @@ def make_model(inputs: list = None, model_inputs = [Input(**x) for x in inputs] di_inputs = template_cast_list_input( - model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs) + model_inputs, + input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + mask_assignment=[0, 0, 1], + index_assignment=[None, None, 0] + ) n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = di_inputs diff --git a/kgcnn/literature/CGCNN/_make.py b/kgcnn/literature/CGCNN/_make.py index 8a970b1a..8b8d3cac 100644 --- a/kgcnn/literature/CGCNN/_make.py +++ b/kgcnn/literature/CGCNN/_make.py @@ -138,9 +138,8 @@ def make_crystal_model(inputs: list = None, model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_edges=int(not make_distances), - has_nodes=1 + int(make_distances), - has_crystal_input=2+2*int(representation == "asu") + mask_assignment=[0, 0 if make_distances else 1, 1, 1, None] + ([0, 1] if representation == "asu" else []), + index_assignment=[None, None, 0, None, None] + ([None, None] if representation == "asu" 
else []) ) if representation == "asu": diff --git a/kgcnn/literature/CMPNN/_make.py b/kgcnn/literature/CMPNN/_make.py index a1a41af5..6a61522f 100644 --- a/kgcnn/literature/CMPNN/_make.py +++ b/kgcnn/literature/CMPNN/_make.py @@ -137,9 +137,8 @@ def make_model(name: str = None, model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_nodes=True, has_edges=True, - has_angle_indices=True, # Treat reverse indices as edge indices - has_edge_indices=True + mask_assignment=[0, 1, 1, 2], + index_assignment=[None, None, 0, 2] ) n, ed, edi, e_pairs, batch_id_node, batch_id_edge, _, node_id, edge_id, _, count_nodes, count_edges, _ = di diff --git a/kgcnn/literature/DGIN/_make.py b/kgcnn/literature/DGIN/_make.py index 0647fcbd..1d80fbc3 100644 --- a/kgcnn/literature/DGIN/_make.py +++ b/kgcnn/literature/DGIN/_make.py @@ -151,10 +151,8 @@ def make_model(name: str = None, di = template_cast_list_input( model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_nodes=True, has_edges=True, - has_graph_state=use_graph_state, - has_angle_indices=True, # Treat reverse indices as edge indices - has_edge_indices=True + index_assignment=[None, None, 0, 2] + ([None] if use_graph_state else []), + mask_assignment=[0, 1, 1, 2] + ([None] if use_graph_state else []) ) if use_graph_state: diff --git a/kgcnn/literature/DMPNN/_make.py b/kgcnn/literature/DMPNN/_make.py index ebf3e85d..cd5b3cad 100644 --- a/kgcnn/literature/DMPNN/_make.py +++ b/kgcnn/literature/DMPNN/_make.py @@ -132,9 +132,8 @@ def make_model(name: str = None, di = template_cast_list_input( model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_nodes=True, has_edges=True, has_graph_state=use_graph_state, - has_angle_indices=True, # Treat reverse indices as edge indices - has_edge_indices=True + mask_assignment=[0,1,1,2] + ([None] if use_graph_state else []), + index_assignment=[None, None, 0, 2] + 
([None] if use_graph_state else []) ) if use_graph_state: diff --git a/kgcnn/literature/DimeNetPP/_make.py b/kgcnn/literature/DimeNetPP/_make.py index 440838cb..3b8fd094 100644 --- a/kgcnn/literature/DimeNetPP/_make.py +++ b/kgcnn/literature/DimeNetPP/_make.py @@ -144,9 +144,8 @@ def make_model(inputs: list = None, model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_edges=False, - has_nodes=2, - has_angle_indices=True, + mask_assignment=[0, 0, 1, 2], + index_assignment=[None, None, 0, 2] ) n, x, edi, adi, batch_id_node, batch_id_edge, batch_id_angles, node_id, edge_id, angle_id, count_nodes, count_edges, count_angles = dj @@ -333,10 +332,8 @@ def make_crystal_model(inputs: list = None, disjoint_inputs = template_cast_list_input( model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_edges=False, - has_nodes=2, - has_angle_indices=True, - has_crystal_input=2 + index_assignment=[None, None, 0, 2, None, None], + mask_assignment=[0, 0, 1, 2, 1, None] ) n, x, edi, angi, img, lattice, batch_id_node, batch_id_edge, batch_id_angles, node_id, edge_id, angle_id, count_nodes, count_edges, count_angles = disjoint_inputs diff --git a/kgcnn/literature/EGNN/_make.py b/kgcnn/literature/EGNN/_make.py index 46f551d6..9308c849 100644 --- a/kgcnn/literature/EGNN/_make.py +++ b/kgcnn/literature/EGNN/_make.py @@ -152,8 +152,8 @@ def make_model(name: str = None, model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_edges=True, - has_nodes=2 + index_assignment=[None, None, None, 0], + mask_assignment=[0, 0, 1, 1] ) n, x, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj diff --git a/kgcnn/literature/GAT/_make.py b/kgcnn/literature/GAT/_make.py index a6fb12e2..e2baf3c9 100644 --- a/kgcnn/literature/GAT/_make.py +++ b/kgcnn/literature/GAT/_make.py @@ -114,7 +114,12 @@ def make_model(inputs: list = None, 
model_inputs = [Input(**x) for x in inputs] di_inputs = template_cast_list_input( - model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs) + model_inputs, + input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + mask_assignment=[0, 1, 1], + index_assignment=[None, None, 0] + ) n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = di_inputs diff --git a/kgcnn/literature/GATv2/_make.py b/kgcnn/literature/GATv2/_make.py index c1f0b9e6..227680ea 100644 --- a/kgcnn/literature/GATv2/_make.py +++ b/kgcnn/literature/GATv2/_make.py @@ -116,7 +116,12 @@ def make_model(inputs: list = None, model_inputs = [Input(**x) for x in inputs] dj_model_inputs = template_cast_list_input( - model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs) + model_inputs, + input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + mask_assignment=[0, 1, 1], + index_assignment=[None, None, 0] + ) n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_model_inputs diff --git a/kgcnn/literature/GCN/_make.py b/kgcnn/literature/GCN/_make.py index e23b4866..120d1fa6 100644 --- a/kgcnn/literature/GCN/_make.py +++ b/kgcnn/literature/GCN/_make.py @@ -111,7 +111,12 @@ def make_model(inputs: list = None, model_inputs = [Input(**x) for x in inputs] dj_inputs = template_cast_list_input( - model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs) + model_inputs, + input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + mask_assignment=[0, 1, 1], + index_assignment=[None, None, 0] + ) n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_inputs diff --git a/kgcnn/literature/GIN/_make.py b/kgcnn/literature/GIN/_make.py index fd42e32a..3b4bac1d 100644 --- 
a/kgcnn/literature/GIN/_make.py +++ b/kgcnn/literature/GIN/_make.py @@ -110,7 +110,8 @@ def make_model(inputs: list = None, disjoint_inputs = template_cast_list_input( model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_edges=False + mask_assignment=[0, 1], + index_assignment=[None, 0] ) n, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs diff --git a/kgcnn/literature/GNNFilm/_make.py b/kgcnn/literature/GNNFilm/_make.py index 25d20ef2..bec390c5 100644 --- a/kgcnn/literature/GNNFilm/_make.py +++ b/kgcnn/literature/GNNFilm/_make.py @@ -114,7 +114,9 @@ def make_model(inputs: list = None, dj_inputs = template_cast_list_input( model_inputs, input_tensor_type=input_tensor_type, - cast_disjoint_kwargs=cast_disjoint_kwargs + cast_disjoint_kwargs=cast_disjoint_kwargs, + mask_assignment=[0, 1, 1], + index_assignment=[None, None, 0] ) n, er, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_inputs diff --git a/kgcnn/literature/GraphSAGE/_make.py b/kgcnn/literature/GraphSAGE/_make.py index 55e9fc4b..dd52f8d9 100644 --- a/kgcnn/literature/GraphSAGE/_make.py +++ b/kgcnn/literature/GraphSAGE/_make.py @@ -120,7 +120,12 @@ def make_model(inputs: list = None, model_inputs = [Input(**x) for x in inputs] dj_inputs = template_cast_list_input( - model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs) + model_inputs, + input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + mask_assignment=[0, 1, 1], + index_assignment=[None, None, 0] + ) n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_inputs diff --git a/kgcnn/literature/HDNNP2nd/_make.py b/kgcnn/literature/HDNNP2nd/_make.py index 8aa9636f..c66a8b67 100644 --- a/kgcnn/literature/HDNNP2nd/_make.py +++ b/kgcnn/literature/HDNNP2nd/_make.py @@ -118,12 +118,11 @@ def 
make_model_weighted(inputs: list = None, model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_edges=False, - has_nodes=2, - has_angle_indices=True + mask_assignment=[0, 0, 1, 2], + index_assignment=[None, None, 0, 0] ) - n, x, disjoint_indices, ang_ind, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj + n, x, disjoint_indices, ang_ind, batch_id_node, batch_id_edge, batch_id_angles, node_id, edge_id, angle_id, count_nodes, count_edges, count_angle = dj out = model_disjoint_weighted( [n, x, disjoint_indices, ang_ind, batch_id_node, count_nodes], @@ -267,12 +266,11 @@ def make_model_behler(inputs: list = None, model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_edges=False, - has_nodes=2, - has_angle_indices=True + mask_assignment=[0, 0, 1, 2], + index_assignment=[None, None, 0, 0] ) - n, x, disjoint_indices, ang_index, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj + n, x, disjoint_indices, ang_index, batch_id_node, batch_id_edge, batch_id_angles, node_id, edge_id, angle_id, count_nodes, count_edges, count_angle = dj out = model_disjoint_behler( [n, x, disjoint_indices, ang_index, batch_id_node, count_nodes], @@ -400,13 +398,12 @@ def make_model_atom_wise(inputs: list = None, model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_edges=False, - has_nodes=2, - has_angle_indices=False, - has_edge_indices=False + mask_assignment=[0, 0], + index_assignment=[None, None] ) - n, x, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj + n, x, batch_id_node, node_id, count_nodes = dj + batch_id_edge, edge_id, count_edges = None, None, None out = model_disjoint_atom_wise( [n, x, batch_id_node, count_nodes], diff --git a/kgcnn/literature/HDNNP2nd/_model.py b/kgcnn/literature/HDNNP2nd/_model.py index 6dd469bf..758609a3 100644 --- 
a/kgcnn/literature/HDNNP2nd/_model.py +++ b/kgcnn/literature/HDNNP2nd/_model.py @@ -3,7 +3,7 @@ from kgcnn.layers.pooling import PoolingNodes from kgcnn.layers.norm import GraphBatchNormalization from ._wacsf import wACSFRad, wACSFAng -from ._acsf import ACSFG2, ACSFG4 +from ._acsf import ACSFG2, ACSFG4, ACSFConstNormalization def model_disjoint_weighted( @@ -19,7 +19,7 @@ def model_disjoint_weighted( output_mlp: dict = None ): # Make input - node_input, xyz_input, edge_index_input, angle_index_input = inputs + node_input, xyz_input, edge_index_input, angle_index_input, batch_id_node, count_nodes = inputs # ACSF representation. rep_rad = wACSFRad(**w_acsf_rad_kwargs)([node_input, xyz_input, edge_index_input]) @@ -28,22 +28,22 @@ def model_disjoint_weighted( # Normalization if normalize_kwargs: - rep = GraphBatchNormalization(**normalize_kwargs)(rep) + rep = GraphBatchNormalization(**normalize_kwargs)([rep, batch_id_node, count_nodes]) if const_normalize_kwargs: rep = ACSFConstNormalization(**const_normalize_kwargs)(rep) # learnable NN. - n = RelationalMLP(**mlp_kwargs)([rep, node_input]) + n = RelationalMLP(**mlp_kwargs)([rep, node_input, batch_id_node, count_nodes]) # Output embedding choice if output_embedding == 'graph': - out = PoolingNodes(**node_pooling_args)(n) + out = PoolingNodes(**node_pooling_args)([count_nodes, n, batch_id_node]) if use_output_mlp: out = MLP(**output_mlp)(out) elif output_embedding == 'node': out = n if use_output_mlp: - out = GraphMLP(**output_mlp)(out) + out = GraphMLP(**output_mlp)([out, batch_id_node, count_nodes]) else: raise ValueError("Unsupported output embedding for mode `HDNNP2nd` .") @@ -63,7 +63,7 @@ def model_disjoint_behler( output_mlp: dict = None ): # Make input - node_input, xyz_input, edge_index_input, angle_index_input = inputs + node_input, xyz_input, edge_index_input, angle_index_input, batch_id_node, count_nodes = inputs # ACSF representation. 
rep_g2 = ACSFG2(**ACSFG2.make_param_table(**g2_kwargs))([node_input, xyz_input, edge_index_input]) @@ -72,22 +72,22 @@ def model_disjoint_behler( # Normalization if normalize_kwargs: - rep = GraphBatchNormalization(**normalize_kwargs)(rep) + rep = GraphBatchNormalization(**normalize_kwargs)([rep, batch_id_node, count_nodes]) if const_normalize_kwargs: rep = ACSFConstNormalization(**const_normalize_kwargs)(rep) # learnable NN. - n = RelationalMLP(**mlp_kwargs)([rep, node_input]) + n = RelationalMLP(**mlp_kwargs)([rep, node_input, batch_id_node, count_nodes]) # Output embedding choice if output_embedding == 'graph': - out = PoolingNodes(**node_pooling_args)(n) + out = PoolingNodes(**node_pooling_args)([count_nodes, n, batch_id_node]) if use_output_mlp: out = MLP(**output_mlp)(out) elif output_embedding == 'node': out = n if use_output_mlp: - out = GraphMLP(**output_mlp)(out) + out = GraphMLP(**output_mlp)([out, batch_id_node, count_nodes]) else: raise ValueError("Unsupported output embedding for mode `HDNNP2nd`") @@ -103,20 +103,20 @@ def model_disjoint_atom_wise( output_mlp: dict = None ): # Make input - node_input, rep_input = inputs + node_input, rep_input, batch_id_node, count_nodes = inputs # learnable NN. 
- n = RelationalMLP(**mlp_kwargs)([rep_input, node_input]) + n = RelationalMLP(**mlp_kwargs)([rep_input, node_input, batch_id_node, count_nodes]) # Output embedding choice if output_embedding == 'graph': - out = PoolingNodes(**node_pooling_args)(n) + out = PoolingNodes(**node_pooling_args)([count_nodes, n, batch_id_node]) if use_output_mlp: out = MLP(**output_mlp)(out) elif output_embedding == 'node': out = n if use_output_mlp: - out = GraphMLP(**output_mlp)(out) + out = GraphMLP(**output_mlp)([out, batch_id_node, count_nodes]) else: raise ValueError("Unsupported output embedding for mode `HDNNP2nd`") diff --git a/kgcnn/literature/HamNet/_make.py b/kgcnn/literature/HamNet/_make.py index 85758343..fb30c691 100644 --- a/kgcnn/literature/HamNet/_make.py +++ b/kgcnn/literature/HamNet/_make.py @@ -28,7 +28,7 @@ "inputs": [ {'shape': (None,), 'name': "node_number", 'dtype': 'int64'}, {'shape': (None, 3), 'name': "node_coordinates", 'dtype': 'float32'}, - {'shape': (None, 32), 'name': "edge_attributes", 'dtype': 'float32'}, + {'shape': (None, 64), 'name': "edge_attributes", 'dtype': 'float32'}, {'shape': (None, 2), 'name': "edge_indices", 'dtype': 'int64'}, {"shape": (), "name": "total_nodes", "dtype": "int64"}, {"shape": (), "name": "total_edges", "dtype": "int64"} @@ -130,9 +130,8 @@ def make_model(name: str = None, model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_nodes=2, - has_edges=True, - has_edge_indices=True + mask_assignment=[0, 0, 1, 1], + index_assignment=[None, None, None, 0] ) n, x, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = di_inputs diff --git a/kgcnn/literature/Megnet/_make.py b/kgcnn/literature/Megnet/_make.py index fac956eb..d2708d8d 100644 --- a/kgcnn/literature/Megnet/_make.py +++ b/kgcnn/literature/Megnet/_make.py @@ -27,6 +27,7 @@ {"shape": (None,), "name": "node_number", "dtype": "int64"}, {"shape": (None, 3), "name": "node_coordinates", "dtype": 
"float32"}, {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, + {'shape': [1], 'name': "charge", 'dtype': 'float32'}, {"shape": (), "name": "graph_number", "dtype": "int64"}, {"shape": (), "name": "total_nodes", "dtype": "int64"}, {"shape": (), "name": "total_edges", "dtype": "int64"} @@ -134,9 +135,8 @@ def make_model(inputs: list = None, model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_edges=(not make_distance), - has_nodes=1 + int(make_distance), - has_graph_state=True + mask_assignment=[0, 0 if make_distance else 1, 1, None], + index_assignment=[None, None, 0, None] ) n, x, disjoint_indices, gs, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj @@ -303,10 +303,8 @@ def make_crystal_model(inputs: list = None, dj = template_cast_list_input( model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_edges=(not make_distance), - has_nodes=1 + int(make_distance), - has_graph_state=True, - has_crystal_input=2 + mask_assignment=[0, 0 if make_distance else 1, 1, None, 1, None], + index_assignment=[None, None, 0, None, None, None] ) n, x, djx, gs, img, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj diff --git a/kgcnn/literature/NMPN/_make.py b/kgcnn/literature/NMPN/_make.py index 832a21b6..b98e4e27 100644 --- a/kgcnn/literature/NMPN/_make.py +++ b/kgcnn/literature/NMPN/_make.py @@ -131,9 +131,13 @@ def make_model(inputs: list = None, # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_inputs = template_cast_list_input(model_inputs, input_tensor_type=input_tensor_type, - cast_disjoint_kwargs=cast_disjoint_kwargs, - has_edges=(not make_distance), has_nodes=1 + int(make_distance)) + disjoint_inputs = template_cast_list_input( + model_inputs, + input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + mask_assignment=[0, 0 if make_distance else 1, 1], + 
index_assignment=[None, None, 0] + ) n, x, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs @@ -186,3 +190,172 @@ def set_scale(*args, **kwargs): make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) + + +model_crystal_default = { + "name": "NMPN", + "inputs": [ + {"shape": (None,), "name": "node_number", "dtype": "int64", "ragged": True}, + {"shape": (None, 3), "name": "node_coordinates", "dtype": "float32", "ragged": True}, + {"shape": (None, 2), "name": "edge_indices", "dtype": "int64", "ragged": True}, + {'shape': (None, 3), 'name': "edge_image", 'dtype': 'int64', 'ragged': True}, + {'shape': (3, 3), 'name': "graph_lattice", 'dtype': 'float32', 'ragged': False} + ], + "input_tensor_type": "ragged", + "cast_disjoint_kwargs": {}, + "input_embedding": None, # deprecated + "input_node_embedding": {"input_dim": 95, "output_dim": 64}, + "input_edge_embedding": {"input_dim": 5, "output_dim": 64}, + "geometric_edge": True, + "make_distance": True, + "expand_distance": True, + "gauss_args": {"bins": 20, "distance": 4, "offset": 0.0, "sigma": 0.4}, + "set2set_args": {"channels": 32, "T": 3, "pooling_method": "sum", + "init_qstar": "0"}, + "pooling_args": {"pooling_method": "sum"}, + "edge_mlp": {"use_bias": True, "activation": "swish", "units": [64, 64, 64]}, + "use_set2set": True, "depth": 3, "node_dim": 64, + "verbose": 10, + "output_embedding": 'graph', + "output_to_tensor": None, + "output_tensor_type": "padded", + "output_scaling": None, + "output_mlp": {"use_bias": [True, True, False], "units": [25, 10, 1], + "activation": ["selu", "selu", "sigmoid"]}, +} + + +@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"]) +def make_crystal_model(inputs: list = None, + input_tensor_type: str = None, + cast_disjoint_kwargs: dict = None, + input_embedding: dict = None, # noqa + input_node_embedding: dict = None, 
+ input_edge_embedding: dict = None, + geometric_edge: bool = None, + make_distance: bool = None, + expand_distance: bool = None, + gauss_args: dict = None, + set2set_args: dict = None, + pooling_args: dict = None, + edge_mlp: dict = None, + use_set2set: bool = None, + node_dim: int = None, + depth: int = None, + verbose: int = None, # noqa + name: str = None, + output_embedding: str = None, + output_to_tensor: bool = None, # noqa + output_mlp: dict = None, + output_tensor_type: str = None, + output_scaling: dict = None + ): + r"""Make `NMPN `_ graph network via functional API. + Default parameters can be found in :obj:`kgcnn.literature.NMPN.model_crystal_default`. + + **Model inputs**: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, coordinates, edge_indices, image_translation, lattice, ...]` + with '...' indicating mask or ID tensors following the template below. + + %s + + **Model outputs**: + The standard output template: + + %s + + Args: + inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition. + input_tensor_type (str): Input type of graph tensor. Default is "padded". + cast_disjoint_kwargs (dict): Dictionary of arguments for casting layer. + input_embedding (dict): Deprecated in favour of input_node_embedding etc. + input_node_embedding (dict): Dictionary of embedding arguments unpacked in :obj:`Embedding` layer. + input_edge_embedding (dict): Dictionary of embedding arguments unpacked in :obj:`Embedding` layer. + geometric_edge (bool): Whether the edges are geometric, like distance or coordinates. + make_distance (bool): Whether input is distance or coordinates at in place of edges. + expand_distance (bool): If the edge input are actual edges or node coordinates instead that are expanded to + form edges with a gauss distance basis given edge indices. Expansion uses `gauss_args`. 
+ gauss_args (dict): Dictionary of layer arguments unpacked in :obj:`GaussBasisLayer` layer. + set2set_args (dict): Dictionary of layer arguments unpacked in :obj:`PoolingSet2SetEncoder` layer. + pooling_args (dict): Dictionary of layer arguments unpacked in :obj:`PoolingNodes`, + `AggregateLocalEdges` layers. + edge_mlp (dict): Dictionary of layer arguments unpacked in :obj:`MLP` layer for edge matrix. + use_set2set (bool): Whether to use :obj:`PoolingSet2SetEncoder` layer. + node_dim (int): Dimension of hidden node embedding. + depth (int): Number of graph embedding units or depth of the network. + verbose (int): Level of verbosity. + name (str): Name of the model. + output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". + output_to_tensor (bool): Deprecated in favour of `output_tensor_type` . + output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. + Defines number of model outputs and activation. + output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None. + output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". 
+ + Returns: + :obj:`keras.models.Model` + """ + # Make input + model_inputs = [Input(**x) for x in inputs] + + dj = template_cast_list_input( + model_inputs, + input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + mask_assignment=[0, 0 if make_distance else 1, 1, 1, None], + index_assignment=[None, None, 0, None, None] + ) + + n, x, d_indices, img, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj + + out = model_disjoint( + [n, x, d_indices, img, lattice, batch_id_node, batch_id_edge, count_nodes, count_edges], + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, + use_edge_embedding=("int" in inputs[1]['dtype']) if input_edge_embedding is not None else False, + input_node_embedding=input_node_embedding, + input_edge_embedding=input_edge_embedding, + geometric_edge=geometric_edge, + make_distance=make_distance, + expand_distance=expand_distance, + gauss_args=gauss_args, + set2set_args=set2set_args, + pooling_args=pooling_args, + edge_mlp=edge_mlp, + use_set2set=use_set2set, + node_dim=node_dim, + depth=depth, + output_embedding=output_embedding, + output_mlp=output_mlp + ) + + if output_scaling is not None: + scaler = get_scaler(output_scaling["name"])(**output_scaling) + if scaler.extensive: + # Node information must be numbers, or we need an additional input. 
+ out = scaler([out, n, batch_id_node]) + else: + out = scaler(out) + + # Output embedding choice + out = template_cast_output( + [out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges], + output_embedding=output_embedding, output_tensor_type=output_tensor_type, + input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, + ) + + model = ks.models.Model(inputs=model_inputs, outputs=out, name=name) + + model.__kgcnn_model_version__ = __model_version__ + + if output_scaling is not None: + def set_scale(*args, **kwargs): + scaler.set_scale(*args, **kwargs) + + setattr(model, "set_scale", set_scale) + + return model + + +make_crystal_model.__doc__ = make_crystal_model.__doc__ % ( + template_cast_list_input.__doc__, template_cast_output.__doc__) diff --git a/kgcnn/literature/NMPN/_model.py b/kgcnn/literature/NMPN/_model.py index 7ab4de2f..03b3c4c1 100644 --- a/kgcnn/literature/NMPN/_model.py +++ b/kgcnn/literature/NMPN/_model.py @@ -2,7 +2,7 @@ from ._layers import TrafoEdgeNetMessages from kgcnn.layers.aggr import AggregateLocalEdges from kgcnn.layers.gather import GatherNodesOutgoing, GatherNodesIngoing -from kgcnn.layers.geom import NodePosition, NodeDistanceEuclidean, GaussBasisLayer +from kgcnn.layers.geom import NodePosition, NodeDistanceEuclidean, GaussBasisLayer, ShiftPeriodicLattice from kgcnn.layers.mlp import GraphMLP, MLP from kgcnn.layers.modules import Embedding from kgcnn.layers.message import MatMulMessages @@ -28,7 +28,6 @@ def model_disjoint(inputs, depth: int = None, output_embedding: str = None, output_mlp: dict = None): - n0, ed, disjoint_indices, batch_id_node, batch_id_edge, count_nodes, count_edges = inputs # embedding, if no feature dimension @@ -85,3 +84,81 @@ def model_disjoint(inputs, raise ValueError("Unsupported output embedding for mode `NMPN` .") return out + + +def model_disjoint_crystal(inputs, + use_node_embedding: bool = None, + use_edge_embedding: bool = None, + input_node_embedding: dict 
= None, + input_edge_embedding: dict = None, + geometric_edge: bool = None, + make_distance: bool = None, + expand_distance: bool = None, + gauss_args: dict = None, + set2set_args: dict = None, + pooling_args: dict = None, + edge_mlp: dict = None, + use_set2set: bool = None, + node_dim: int = None, + depth: int = None, + output_embedding: str = None, + output_mlp: dict = None): + n0, ed, disjoint_indices, edge_image, lattice, batch_id_node, batch_id_edge, count_nodes, count_edges = inputs + + # embedding, if no feature dimension + if use_node_embedding: + n0 = Embedding(**input_node_embedding)(n0) + + if not geometric_edge: + if use_edge_embedding: + ed = Embedding(**input_edge_embedding)(ed) + + # If coordinates are in place of edges + if make_distance: + x = ed + pos1, pos2 = NodePosition()([x, disjoint_indices]) + pos2 = ShiftPeriodicLattice()([pos2, edge_image, lattice]) + ed = NodeDistanceEuclidean()([pos1, pos2]) + + if expand_distance: + ed = GaussBasisLayer(**gauss_args)(ed) + + # Make hidden dimension + n = ks.layers.Dense(node_dim, activation="linear")(n0) + + # Make edge networks. 
+ edge_net_in = GraphMLP(**edge_mlp)([ed, batch_id_edge, count_edges]) + edge_net_in = TrafoEdgeNetMessages(target_shape=(node_dim, node_dim))(edge_net_in) + edge_net_out = GraphMLP(**edge_mlp)([ed, batch_id_edge, count_edges]) + edge_net_out = TrafoEdgeNetMessages(target_shape=(node_dim, node_dim))(edge_net_out) + + # Gru for node updates + gru = GRUUpdate(node_dim) + + for i in range(0, depth): + n_in = GatherNodesOutgoing()([n, disjoint_indices]) + n_out = GatherNodesIngoing()([n, disjoint_indices]) + m_in = MatMulMessages()([edge_net_in, n_in]) + m_out = MatMulMessages()([edge_net_out, n_out]) + eu = ks.layers.Concatenate(axis=-1)([m_in, m_out]) + eu = AggregateLocalEdges(**pooling_args)([n, eu, disjoint_indices]) # Summing for each node connections + n = gru([n, eu]) + + n = ks.layers.Concatenate(axis=-1)([n0, n]) + + # Output embedding choice + if output_embedding == 'graph': + if use_set2set: + # output + n = ks.layers.Dense(units=set2set_args['channels'], activation="linear")(n) + out = PoolingSet2SetEncoder(**set2set_args)([count_nodes, n, batch_id_node]) + else: + out = PoolingNodes(**pooling_args)([count_nodes, n, batch_id_node]) + out = ks.layers.Flatten()(out) # Flatten() required for to Set2Set output. 
+ out = MLP(**output_mlp)(out) + elif output_embedding == 'node': + out = GraphMLP(**output_mlp)([n, batch_id_node, count_nodes]) + else: + raise ValueError("Unsupported output embedding for mode `NMPN` .") + + return out diff --git a/kgcnn/literature/PAiNN/_make.py b/kgcnn/literature/PAiNN/_make.py index 7bd222e0..56e68ce1 100644 --- a/kgcnn/literature/PAiNN/_make.py +++ b/kgcnn/literature/PAiNN/_make.py @@ -120,10 +120,12 @@ def make_model(inputs: list = None, # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_inputs = template_cast_list_input(model_inputs, input_tensor_type=input_tensor_type, - cast_disjoint_kwargs=cast_disjoint_kwargs, - has_nodes=2 + int(has_equivariant_input), - has_edges=False) + disjoint_inputs = template_cast_list_input( + model_inputs, input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + mask_assignment=([0] if has_equivariant_input else []) + [0, 0, 1], + index_assignment=([None] if has_equivariant_input else []) + [None, None, 0 + int(has_equivariant_input)] + ) if not has_equivariant_input: z, x, edi, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs @@ -278,8 +280,11 @@ def make_crystal_model(inputs: list = None, dj_inputs = template_cast_list_input( model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_nodes=2 + int(has_equivariant_input), has_crystal_input=2, - has_edges=False) + mask_assignment=([0] if has_equivariant_input else []) + [ + 0, 0, 1, 1, None], + index_assignment=([0] if has_equivariant_input else []) + [ + None, None, 0 + int(has_equivariant_input), None, None] + ) if not has_equivariant_input: z, x, edi, img, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_inputs diff --git a/kgcnn/literature/RGCN/_make.py b/kgcnn/literature/RGCN/_make.py index 220185f3..e044ad55 100644 --- a/kgcnn/literature/RGCN/_make.py +++ b/kgcnn/literature/RGCN/_make.py @@ 
-118,9 +118,8 @@ def make_model(inputs: list = None, model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - has_nodes=True, - has_edges=2, - has_edge_indices=True + mask_assignment=[0, 1, 1, 1], + index_assignment=[None, None, None, 0] ) n, ed, er, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_inputs diff --git a/kgcnn/literature/Schnet/_make.py b/kgcnn/literature/Schnet/_make.py index 69eeda3f..3c19e1d7 100644 --- a/kgcnn/literature/Schnet/_make.py +++ b/kgcnn/literature/Schnet/_make.py @@ -124,9 +124,12 @@ def make_model(inputs: list = None, # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_inputs = template_cast_list_input(model_inputs, input_tensor_type=input_tensor_type, - cast_disjoint_kwargs=cast_disjoint_kwargs, - has_edges=(not make_distance), has_nodes=1 + int(make_distance)) + disjoint_inputs = template_cast_list_input( + model_inputs, input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + mask_assignment=[0, 0 if make_distance else 1, 1], + index_assignment=[None, None, 0] + ) n, x, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs @@ -269,10 +272,12 @@ def make_crystal_model(inputs: list = None, # Make input model_inputs = [Input(**x) for x in inputs] - disjoint_inputs = template_cast_list_input(model_inputs, input_tensor_type=input_tensor_type, - cast_disjoint_kwargs=cast_disjoint_kwargs, - has_edges=(not make_distance), has_nodes=1 + int(make_distance), - has_crystal_input=2) + disjoint_inputs = template_cast_list_input( + model_inputs, input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + mask_assignment=[0, 0 if make_distance else 1, 1, 1, None], + index_assignment=[None, None, 0, None, None] + ) n, x, djx, img, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs diff --git 
a/kgcnn/models/casting.py b/kgcnn/models/casting.py index cb2a591f..11fdf0ef 100644 --- a/kgcnn/models/casting.py +++ b/kgcnn/models/casting.py @@ -8,7 +8,10 @@ def template_cast_output(model_outputs, - output_embedding, output_tensor_type, input_tensor_type, cast_disjoint_kwargs): + output_embedding, + output_tensor_type, + input_tensor_type, + cast_disjoint_kwargs): r"""The standard model output template returns a single tensor of either "graph", "node", or "edge" embeddings specified by :obj:`output_embedding` within the model. The return tensor type is determined by :obj:`output_tensor_type` . Options are: @@ -69,13 +72,8 @@ def template_cast_output(model_outputs, def template_cast_list_input(model_inputs, input_tensor_type, cast_disjoint_kwargs, - has_nodes: Union[int, bool] = True, - has_edges: Union[int, bool] = True, - has_angles: Union[int, bool] = False, - has_edge_indices: Union[int, bool] = True, - has_angle_indices: Union[int, bool] = False, - has_graph_state: Union[int, bool] = False, - has_crystal_input: Union[int, bool] = False, + mask_assignment: list = None, + index_assignment: list = None, return_sub_id: bool = True): r"""Template of listed graph input tensors, which should be compatible to previous kgcnn versions and defines the order as follows: :obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, ...]` . @@ -154,161 +152,133 @@ def template_cast_list_input(model_inputs, - edges_count (Tensor): Tensor of number of edges for each graph of shape `(batch, )` . - angles_count (Tensor): Tensor of number of angles for each graph of shape `(batch, )` . 
""" - standard_inputs = [x for x in model_inputs] - - batched_nodes = [] - batched_edges = [] - batched_angles = [] - batched_state = [] - batched_indices = [] - batched_angle_indices = [] - batched_crystal_info = [] - - for i in range(int(has_nodes)): - batched_nodes.append(standard_inputs.pop(0)) - for i in range(int(has_edges)): - batched_edges.append(standard_inputs.pop(0)) - for i in range(int(has_angles)): - batched_angles.append(standard_inputs.pop(0)) - for i in range(int(has_edge_indices)): - batched_indices.append(standard_inputs.pop(0)) - for i in range(int(has_angle_indices)): - batched_angle_indices.append(standard_inputs.pop(0)) - for i in range(int(has_graph_state)): - batched_state.append(standard_inputs.pop(0)) - for i in range(int(has_crystal_input)): - batched_crystal_info.append(standard_inputs.pop(0)) - - batched_id = standard_inputs - - disjoint_nodes = [] - disjoint_edges = [] - disjoint_state = [] - disjoint_angles = [] - disjoint_indices = [] - disjoint_angle_indices = [] - disjoint_crystal_info = [] - disjoint_id = [] + out_tensor = [] + out_batch_id = [] + out_graph_id = [] + out_totals = [] + is_already_disjoint = False if input_tensor_type in ["padded", "masked"]: + if mask_assignment is None or not isinstance(mask_assignment, (list, tuple)): + raise ValueError() - if int(has_angle_indices) > 0: - part_nodes, part_edges, part_angle = batched_id + reduced_mask = [x for x in mask_assignment if x is not None] + if len(reduced_mask) == 0: + num_mask = 0 else: - part_nodes, part_edges = batched_id - part_angle = None - - for x in batched_indices: - _, idx, batch_id_node, batch_id_edge, node_id, edge_id, len_nodes, len_edges = CastBatchedIndicesToDisjoint( - **cast_disjoint_kwargs)([batched_nodes[0], x, part_nodes, part_edges]) - disjoint_indices.append(idx) - - for x in batched_angle_indices: - _, idx, _, batch_id_ang, _, ang_id, _, len_ang = CastBatchedIndicesToDisjoint( - **cast_disjoint_kwargs)([batched_indices[0], x, part_edges, 
part_angle]) - disjoint_angle_indices.append(idx) - - for x in batched_nodes: - disjoint_nodes.append( - CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([x, part_nodes])[0]) - - for x in batched_edges: - disjoint_edges.append( - CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([x, part_edges])[0]) - - for x in batched_angles: - disjoint_angles.append( - CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([x, part_angle])[0]) - - for x in batched_state: - disjoint_state.append( - CastBatchedGraphStateToDisjoint(**cast_disjoint_kwargs)(x)) - - if has_crystal_input > 0: - disjoint_crystal_info.append( - CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([batched_crystal_info[0], part_edges])[0] - ) - disjoint_crystal_info.append( - CastBatchedGraphStateToDisjoint(**cast_disjoint_kwargs)(batched_crystal_info[1]) - ) - if has_crystal_input > 2: - # multiplicity - disjoint_crystal_info.append( - CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([batched_crystal_info[2], part_nodes])[0] - ) - # symmetry operations - disjoint_crystal_info.append( - CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([batched_crystal_info[3], part_edges])[0] - ) + num_mask = max(reduced_mask) + 1 + if len(mask_assignment) + num_mask != len(model_inputs): + raise ValueError() + + values_input = model_inputs[:-num_mask] + mask_input = model_inputs[-num_mask:] + + if index_assignment is None: + index_assignment = [None for _ in range(len(values_input))] + if len(index_assignment) != len(mask_assignment): + raise ValueError() + + out_tensor = [None for _ in range(len(values_input))] + out_batch_id = [None for _ in range(num_mask)] + out_graph_id = [None for _ in range(num_mask)] + out_totals = [None for _ in range(num_mask)] + + for i, i_ref in enumerate(index_assignment): + if i_ref is None: + continue + assert isinstance(i_ref, int), "Must provide positional index of for reference of indices." 
+ ref = values_input[i_ref] + x = values_input[i] + m, m_ref = mask_assignment[i], mask_assignment[i_ref] + ref_mask = mask_input[m_ref] + x_mask = mask_input[m] + o_ref, o_x, b_r, b_x, g_r, g_x, t_r, t_x = CastBatchedIndicesToDisjoint( + **cast_disjoint_kwargs)([ref, x, ref_mask, x_mask]) + out_tensor[i] = o_x + # Important to not overwrite indices with simple values here. + if out_tensor[i_ref] is None: + out_tensor[i_ref] = o_ref + out_batch_id[m] = b_x + out_batch_id[m_ref] = b_r + out_graph_id[m] = g_x + out_graph_id[m_ref] = g_r + out_totals[m] = t_x + out_totals[m_ref] = t_r + + for i, x in enumerate(values_input): + if out_tensor[i] is not None: + continue + m = mask_assignment[i] + if m is None: + out_tensor[i] = CastBatchedGraphStateToDisjoint(**cast_disjoint_kwargs)(x) + continue + + x_mask = mask_input[m] + o_x, bi, gi, tot = CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([x, x_mask]) + out_tensor[i] = o_x + out_batch_id[m] = bi + out_graph_id[m] = gi + out_totals[m] = tot elif input_tensor_type in ["ragged", "jagged"]: + if index_assignment is None: + index_assignment = [None for _ in range(len(model_inputs))] + if len(index_assignment) != len(model_inputs): + raise ValueError() + + reduced_mask = [x for x in mask_assignment if x is not None] + if len(reduced_mask) == 0: + num_mask = 0 + else: + num_mask = max(reduced_mask) + 1 + + out_tensor = [None for _ in range(len(model_inputs))] + out_batch_id = [None for _ in range(num_mask)] + out_graph_id = [None for _ in range(num_mask)] + out_totals = [None for _ in range(num_mask)] + + for i, i_ref in enumerate(index_assignment): + if i_ref is None: + continue + assert isinstance(i_ref, int), "Must provide positional index of for reference of indices."
+ ref = model_inputs[i_ref] + x = model_inputs[i] + m, m_ref = mask_assignment[i], mask_assignment[i_ref] + o_ref, o_x, b_r, b_x, g_r, g_x, t_r, t_x = CastRaggedIndicesToDisjoint( + **cast_disjoint_kwargs)([ref, x]) + out_tensor[i] = o_x + # Important to no overwrite indices with simple values here. + if out_tensor[i_ref] is None: + out_tensor[i_ref] = o_ref + out_batch_id[m] = b_x + out_batch_id[m_ref] = b_r + out_graph_id[m] = g_x + out_graph_id[m_ref] = g_r + out_totals[m] = t_x + out_totals[m_ref] = t_r + + for i, x in enumerate(model_inputs): + if out_tensor[i] is not None: + continue + m = mask_assignment[i] + if m is not None: + o_x, bi, gi, tot = CastRaggedAttributesToDisjoint(**cast_disjoint_kwargs)(x) + out_tensor[i] = o_x + out_batch_id[m] = bi + out_graph_id[m] = gi + out_totals[m] = tot + else: + out_tensor[i] = x - for x in batched_indices: - _, idx, batch_id_node, batch_id_edge, node_id, edge_id, len_nodes, len_edges = CastRaggedIndicesToDisjoint( - **cast_disjoint_kwargs)([batched_nodes[0], x]) - disjoint_indices.append(idx) - - for x in batched_angle_indices: - _, idx, _, batch_id_ang, _, ang_id, _, len_ang = CastRaggedIndicesToDisjoint( - **cast_disjoint_kwargs)([batched_indices[0], x]) - disjoint_angle_indices.append(idx) - - for x in batched_nodes: - disjoint_nodes.append( - CastRaggedAttributesToDisjoint(**cast_disjoint_kwargs)(x)[0]) - - for x in batched_edges: - disjoint_edges.append( - CastRaggedAttributesToDisjoint(**cast_disjoint_kwargs)(x)[0]) - - for x in batched_angles: - disjoint_angles.append( - CastRaggedAttributesToDisjoint(**cast_disjoint_kwargs)(x)[0]) - - if has_crystal_input > 0: - disjoint_crystal_info.append( - CastRaggedAttributesToDisjoint(**cast_disjoint_kwargs)(batched_crystal_info[0])[0] - ) - disjoint_crystal_info.append( - batched_crystal_info[1] - ) - if has_crystal_input > 2: - disjoint_crystal_info.append( - CastRaggedAttributesToDisjoint(**cast_disjoint_kwargs)(batched_crystal_info[2])[0] - ) - 
disjoint_crystal_info.append( - CastRaggedAttributesToDisjoint(**cast_disjoint_kwargs)(batched_crystal_info[3])[0] - ) - - disjoint_state = batched_state + else: + is_already_disjoint = True + if is_already_disjoint: + out = model_inputs else: - disjoint_nodes = batched_nodes - disjoint_edges = batched_edges - disjoint_indices = batched_indices - disjoint_state = batched_state - disjoint_angle_indices = batched_angle_indices - disjoint_angles = batched_angles - disjoint_crystal_info = batched_crystal_info - - if input_tensor_type in ["ragged", "jagged", "padded", "masked"]: - disjoint_id.append(batch_id_node) # noqa - disjoint_id.append(batch_id_edge) # noqa - if int(has_angle_indices) > 0: - disjoint_id.append(batch_id_ang) # noqa + out = out_tensor + out_batch_id if return_sub_id: - disjoint_id.append(node_id) # noqa - disjoint_id.append(edge_id) # noqa - if int(has_angle_indices) > 0: - disjoint_id.append(ang_id) # noqa - disjoint_id.append(len_nodes) # noqa - disjoint_id.append(len_edges) # noqa - if int(has_angle_indices) > 0: - disjoint_id.append(len_ang) # noqa - else: - disjoint_id = batched_id - - disjoint_model_inputs = disjoint_nodes + disjoint_edges + disjoint_angles + disjoint_indices + disjoint_angle_indices + disjoint_state + disjoint_crystal_info + disjoint_id - - return disjoint_model_inputs + out = out + out_graph_id + out = out + out_totals + return out diff --git a/training/hyper/hyper_esol.py b/training/hyper/hyper_esol.py index 0b3d43af..8ea2ecc0 100644 --- a/training/hyper/hyper_esol.py +++ b/training/hyper/hyper_esol.py @@ -1166,4 +1166,74 @@ "kgcnn_version": "4.0.0" } }, + "HDNNP2nd": { + "model": { + "class_name": "make_model", + "module_name": "kgcnn.literature.Schnet", + "config": { + "name": "HDNNP2nd", + "inputs": [ + {"shape": (None,), "name": "node_number", "dtype": "int64"}, + {"shape": (None, 3), "name": "node_coordinates", "dtype": "float32"}, + {"shape": (None, 2), "name": "range_indices", "dtype": "int64"}, + {"shape": (None, 
3), "name": "angle_indices_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_ranges", "dtype": "int64"}, + {"shape": (), "name": "total_angles", "dtype": "int64"} + ], + "input_tensor_type": "padded", + "cast_disjoint_kwargs": {}, + "w_acsf_ang_kwargs": {}, + "w_acsf_rad_kwargs": {}, + "mlp_kwargs": {"units": [128, 128, 128, 1], + "num_relations": 96, + "activation": ["swish", "swish", "swish", "linear"]}, + "node_pooling_args": {"pooling_method": "sum"}, + "verbose": 10, + "output_embedding": "graph", "output_to_tensor": True, + "use_output_mlp": False, + "output_mlp": {"use_bias": [True, True], "units": [64, 1], + "activation": ["swish", "linear"]} + } + }, + "training": { + "cross_validation": {"class_name": "KFold", + "config": {"n_splits": 5, "random_state": 42, "shuffle": True}}, + "scaler": {"class_name": "StandardLabelScaler", + "config": {"with_std": True, "with_mean": True, "copy": True}}, + "fit": { + "batch_size": 64, "epochs": 500, "validation_freq": 10, "verbose": 2, + "callbacks": [ + {"class_name": "kgcnn>LinearLearningRateScheduler", "config": { + "learning_rate_start": 0.001, "learning_rate_stop": 1e-05, "epo_min": 100, "epo": 500, + "verbose": 0} + } + ] + }, + "compile": { + "optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.001}}, + "loss": "mean_absolute_error" + } + }, + "data": { + "dataset": { + "class_name": "ESOLDataset", + "module_name": "kgcnn.data.datasets.ESOLDataset", + "config": {}, + "methods": [ + {"map_list": {"method": "set_range", "max_distance": 8, "max_neighbours": 10000}}, + {"map_list": {"method": "set_angle"}}, + {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges", + "count_edges": "range_indices"}}, + {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_angles", + "count_edges": "angle_indices"}}, + ] + }, + }, + "info": { + "postfix": "", + "postfix_file": "", + "kgcnn_version": "4.0.0" + } + }, }