1
1
import keras_core as ks
2
2
from kgcnn .layers .casting import (CastBatchedIndicesToDisjoint , CastBatchedAttributesToDisjoint ,
3
3
CastDisjointToGraphState , CastDisjointToBatchedAttributes , CastGraphStateToDisjoint )
4
- from kgcnn .layers .gather import GatherNodesOutgoing , GatherState
5
- from keras_core .layers import Dense , Concatenate , Activation , Add , Dropout
6
- from kgcnn .layers .modules import Embedding
7
- from kgcnn .layers .mlp import GraphMLP , MLP
8
- from kgcnn .layers .aggr import AggregateLocalEdges
9
- from kgcnn .layers .pooling import PoolingNodes
10
4
from kgcnn .layers .scale import get as get_scaler
11
- from ._layers import DMPNNPPoolingEdgesDirected
12
5
from kgcnn .models .utils import update_model_kwargs
13
6
from keras_core .backend import backend as backend_to_use
7
+ from ._model import model_disjoint
14
8
15
9
# Keep track of model version from commit date in literature.
16
10
# To be updated if model is changed in a significant way.
17
- __model_version__ = "2023-09-26 "
11
+ __model_version__ = "2023-10-23 "
18
12
19
13
# Supported backends
20
14
__kgcnn_model_backend_supported__ = ["tensorflow" , "torch" , "jax" ]
15
+
21
16
if backend_to_use () not in __kgcnn_model_backend_supported__ :
22
17
raise NotImplementedError ("Backend '%s' for model 'DMPNN' is not supported." % backend_to_use ())
23
18
@@ -82,24 +77,27 @@ def make_model(name: str = None,
82
77
Default parameters can be found in :obj:`kgcnn.literature.DMPNN.model_default`.
83
78
84
79
Inputs:
85
- list: `[node_attributes, edge_attributes, edge_indices, edge_pairs]` or
86
- `[node_attributes, edge_attributes, edge_indices, edge_pairs, state_attributes]` if `use_graph_state=True`
80
+ list: `[node_attributes, edge_attributes, edge_indices, edge_pairs, total_nodes, total_edges]` or
81
+ `[node_attributes, edge_attributes, edge_indices, edge_pairs, total_nodes, total_edges, state_attributes]`
82
+ if `use_graph_state=True` .
87
83
88
- - node_attributes (tf.RaggedTensor ): Node attributes of shape `(batch, None, F)` or `(batch, None)`
84
+ - node_attributes (Tensor ): Node attributes of shape `(batch, None, F)` or `(batch, None)`
89
85
using an embedding layer.
90
- - edge_attributes (tf.RaggedTensor ): Edge attributes of shape `(batch, None, F)` or `(batch, None)`
86
+ - edge_attributes (Tensor ): Edge attributes of shape `(batch, None, F)` or `(batch, None)`
91
87
using an embedding layer.
92
- - edge_indices (tf.RaggedTensor ): Index list for edges of shape `(batch, None, 2)`.
93
- - edge_pairs (tf.RaggedTensor ): Pair mappings for reverse edge for each edge `(batch, None, 1)`.
94
- - state_attributes (tf. Tensor): Environment or graph state attributes of shape `(batch, F)` or `(batch,)`
88
+ - edge_indices (Tensor ): Index list for edges of shape `(batch, None, 2)`.
89
+ - edge_pairs (Tensor ): Pair mappings for reverse edge for each edge `(batch, None, 1)`.
90
+ - state_attributes (Tensor): Environment or graph state attributes of shape `(batch, F)` or `(batch,)`
95
91
using an embedding layer.
92
+ - total_nodes(Tensor): Number of Nodes in graph of shape `(batch, )` .
93
+ - total_edges(Tensor): Number of Edges in graph of shape `(batch, )` .
96
94
97
95
Outputs:
98
- tf. Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`.
96
+ Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`.
99
97
100
98
Args:
101
99
name (str): Name of the model. Should be "DMPNN".
102
- inputs (list): List of dictionaries unpacked in :obj:`tf. keras.layers.Input`. Order must match model definition.
100
+ inputs (list): List of dictionaries unpacked in :obj:`keras.layers.Input`. Order must match model definition.
103
101
cast_disjoint_kwargs (dict): Dictionary of arguments for :obj:`CastBatchedIndicesToDisjoint` .
104
102
input_node_embedding (dict): Dictionary of arguments for nodes unpacked in :obj:`Embedding` layers.
105
103
input_edge_embedding (dict): Dictionary of arguments for edge unpacked in :obj:`Embedding` layers.
@@ -115,9 +113,10 @@ def make_model(name: str = None,
115
113
verbose (int): Level for print information.
116
114
use_graph_state (bool): Whether to use graph state information. Default is False.
117
115
output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph".
118
- output_to_tensor (bool): Whether to cast model output to :obj:`tf. Tensor`.
116
+ output_to_tensor (bool): Whether to cast model output to :obj:`Tensor`.
119
117
output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block.
120
118
Defines number of model outputs and activation.
119
+ output_scaling (dict): Kwargs for scaling layer, if scaling layer is to be used.
121
120
122
121
Returns:
123
122
:obj:`keras.models.Model`
@@ -167,71 +166,3 @@ def set_scale(*args, **kwargs):
167
166
168
167
setattr (model , "set_scale" , set_scale )
169
168
return model
170
-
171
-
172
- def model_disjoint (inputs : list = None ,
173
- use_node_embedding : bool = None ,
174
- use_edge_embedding : bool = None ,
175
- use_graph_embedding : bool = None ,
176
- input_node_embedding : dict = None ,
177
- input_edge_embedding : dict = None ,
178
- input_graph_embedding : dict = None ,
179
- pooling_args : dict = None ,
180
- edge_initialize : dict = None ,
181
- edge_dense : dict = None ,
182
- edge_activation : dict = None ,
183
- node_dense : dict = None ,
184
- dropout : dict = None ,
185
- depth : int = None ,
186
- use_graph_state : bool = False ,
187
- output_embedding : str = None ,
188
- output_mlp : dict = None ):
189
- n , ed , edi , batch_id_node , ed_pairs , count_nodes , graph_state = inputs
190
-
191
- # Embedding, if no feature dimension
192
- if use_node_embedding :
193
- n = Embedding (** input_node_embedding )(n )
194
- if use_edge_embedding :
195
- ed = Embedding (** input_edge_embedding )(ed )
196
- if use_graph_state :
197
- if use_graph_embedding :
198
- graph_state = Embedding (** input_graph_embedding )(graph_state )
199
-
200
- # Make first edge hidden h0
201
- h_n0 = GatherNodesOutgoing ()([n , edi ])
202
- h0 = Concatenate (axis = - 1 )([h_n0 , ed ])
203
- h0 = Dense (** edge_initialize )(h0 )
204
-
205
- # One Dense layer for all message steps
206
- edge_dense_all = Dense (** edge_dense ) # Should be linear activation
207
-
208
- # Model Loop
209
- h = h0
210
- for i in range (depth ):
211
- m_vw = DMPNNPPoolingEdgesDirected ()([n , h , edi , ed_pairs ])
212
- h = edge_dense_all (m_vw )
213
- h = Add ()([h , h0 ])
214
- h = Activation (** edge_activation )(h )
215
- if dropout is not None :
216
- h = Dropout (** dropout )(h )
217
-
218
- mv = AggregateLocalEdges (** pooling_args )([n , h , edi ])
219
- mv = Concatenate (axis = - 1 )([mv , n ])
220
- hv = Dense (** node_dense )(mv )
221
-
222
- # Output embedding choice
223
- n = hv
224
- if output_embedding == 'graph' :
225
- out = PoolingNodes (** pooling_args )([count_nodes , n , batch_id_node ])
226
- if use_graph_state :
227
- out = ks .layers .Concatenate ()([graph_state , out ])
228
- out = MLP (** output_mlp )(out )
229
- elif output_embedding == 'node' :
230
- if use_graph_state :
231
- graph_state_node = GatherState ()([graph_state , n ])
232
- n = Concatenate ()([n , graph_state_node ])
233
- out = GraphMLP (** output_mlp )(n )
234
- else :
235
- raise ValueError ("Unsupported embedding mode for `DMPNN`." )
236
-
237
- return out
0 commit comments