Commit a6a3cfe

refactoring for keras 3.0

1 parent dcc2905 · commit a6a3cfe

File tree

13 files changed: +192 -169 lines changed

kgcnn/layers/geom.py

Lines changed: 79 additions & 68 deletions
Large diffs are not rendered by default.

kgcnn/literature/DMPNN/_layers.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 
 class DMPNNPPoolingEdgesDirected(ks.layers.Layer):  # noqa
     """Pooling of edges for around a target node as defined by
+
     `DMPNN <https://pubs.acs.org/doi/full/10.1021/acs.jcim.9b00237>`__ . This is slightly different as the normal node
     aggregation from message passing like networks. Requires edge pair indices for this implementation.
     """

kgcnn/literature/DMPNN/_make.py

Lines changed: 17 additions & 86 deletions
@@ -1,23 +1,18 @@
 import keras_core as ks
 from kgcnn.layers.casting import (CastBatchedIndicesToDisjoint, CastBatchedAttributesToDisjoint,
                                   CastDisjointToGraphState, CastDisjointToBatchedAttributes, CastGraphStateToDisjoint)
-from kgcnn.layers.gather import GatherNodesOutgoing, GatherState
-from keras_core.layers import Dense, Concatenate, Activation, Add, Dropout
-from kgcnn.layers.modules import Embedding
-from kgcnn.layers.mlp import GraphMLP, MLP
-from kgcnn.layers.aggr import AggregateLocalEdges
-from kgcnn.layers.pooling import PoolingNodes
 from kgcnn.layers.scale import get as get_scaler
-from ._layers import DMPNNPPoolingEdgesDirected
 from kgcnn.models.utils import update_model_kwargs
 from keras_core.backend import backend as backend_to_use
+from ._model import model_disjoint
 
 # Keep track of model version from commit date in literature.
 # To be updated if model is changed in a significant way.
-__model_version__ = "2023-09-26"
+__model_version__ = "2023-10-23"
 
 # Supported backends
 __kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"]
+
 if backend_to_use() not in __kgcnn_model_backend_supported__:
     raise NotImplementedError("Backend '%s' for model 'DMPNN' is not supported." % backend_to_use())
 

@@ -82,24 +77,27 @@ def make_model(name: str = None,
     Default parameters can be found in :obj:`kgcnn.literature.DMPNN.model_default`.
 
     Inputs:
-        list: `[node_attributes, edge_attributes, edge_indices, edge_pairs]` or
-            `[node_attributes, edge_attributes, edge_indices, edge_pairs, state_attributes]` if `use_graph_state=True`
+        list: `[node_attributes, edge_attributes, edge_indices, edge_pairs, total_nodes, total_edges]` or
+            `[node_attributes, edge_attributes, edge_indices, edge_pairs, total_nodes, total_edges, state_attributes]`
+            if `use_graph_state=True` .
 
-        - node_attributes (tf.RaggedTensor): Node attributes of shape `(batch, None, F)` or `(batch, None)`
+        - node_attributes (Tensor): Node attributes of shape `(batch, None, F)` or `(batch, None)`
             using an embedding layer.
-        - edge_attributes (tf.RaggedTensor): Edge attributes of shape `(batch, None, F)` or `(batch, None)`
+        - edge_attributes (Tensor): Edge attributes of shape `(batch, None, F)` or `(batch, None)`
            using an embedding layer.
-        - edge_indices (tf.RaggedTensor): Index list for edges of shape `(batch, None, 2)`.
-        - edge_pairs (tf.RaggedTensor): Pair mappings for reverse edge for each edge `(batch, None, 1)`.
-        - state_attributes (tf.Tensor): Environment or graph state attributes of shape `(batch, F)` or `(batch,)`
+        - edge_indices (Tensor): Index list for edges of shape `(batch, None, 2)`.
+        - edge_pairs (Tensor): Pair mappings for reverse edge for each edge `(batch, None, 1)`.
+        - state_attributes (Tensor): Environment or graph state attributes of shape `(batch, F)` or `(batch,)`
            using an embedding layer.
+        - total_nodes(Tensor): Number of Nodes in graph of shape `(batch, )` .
+        - total_edges(Tensor): Number of Edges in graph of shape `(batch, )` .
 
     Outputs:
-        tf.Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`.
+        Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`.
 
     Args:
         name (str): Name of the model. Should be "DMPNN".
-        inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition.
+        inputs (list): List of dictionaries unpacked in :obj:`keras.layers.Input`. Order must match model definition.
         cast_disjoint_kwargs (dict): Dictionary of arguments for :obj:`CastBatchedIndicesToDisjoint` .
         input_node_embedding (dict): Dictionary of arguments for nodes unpacked in :obj:`Embedding` layers.
        input_edge_embedding (dict): Dictionary of arguments for edge unpacked in :obj:`Embedding` layers.

@@ -115,9 +113,10 @@ def make_model(name: str = None,
        verbose (int): Level for print information.
        use_graph_state (bool): Whether to use graph state information. Default is False.
        output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph".
-        output_to_tensor (bool): Whether to cast model output to :obj:`tf.Tensor`.
+        output_to_tensor (bool): Whether to cast model output to :obj:`Tensor`.
        output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block.
            Defines number of model outputs and activation.
+        output_scaling (dict): Kwargs for scaling layer, if scaling layer is to be used.
 
     Returns:
         :obj:`keras.models.Model`

@@ -167,71 +166,3 @@ def set_scale(*args, **kwargs):
 
     setattr(model, "set_scale", set_scale)
     return model
-
-
-def model_disjoint(inputs: list = None,
-                   use_node_embedding: bool = None,
-                   use_edge_embedding: bool = None,
-                   use_graph_embedding: bool = None,
-                   input_node_embedding: dict = None,
-                   input_edge_embedding: dict = None,
-                   input_graph_embedding: dict = None,
-                   pooling_args: dict = None,
-                   edge_initialize: dict = None,
-                   edge_dense: dict = None,
-                   edge_activation: dict = None,
-                   node_dense: dict = None,
-                   dropout: dict = None,
-                   depth: int = None,
-                   use_graph_state: bool = False,
-                   output_embedding: str = None,
-                   output_mlp: dict = None):
-    n, ed, edi, batch_id_node, ed_pairs, count_nodes, graph_state = inputs
-
-    # Embedding, if no feature dimension
-    if use_node_embedding:
-        n = Embedding(**input_node_embedding)(n)
-    if use_edge_embedding:
-        ed = Embedding(**input_edge_embedding)(ed)
-    if use_graph_state:
-        if use_graph_embedding:
-            graph_state = Embedding(**input_graph_embedding)(graph_state)
-
-    # Make first edge hidden h0
-    h_n0 = GatherNodesOutgoing()([n, edi])
-    h0 = Concatenate(axis=-1)([h_n0, ed])
-    h0 = Dense(**edge_initialize)(h0)
-
-    # One Dense layer for all message steps
-    edge_dense_all = Dense(**edge_dense)  # Should be linear activation
-
-    # Model Loop
-    h = h0
-    for i in range(depth):
-        m_vw = DMPNNPPoolingEdgesDirected()([n, h, edi, ed_pairs])
-        h = edge_dense_all(m_vw)
-        h = Add()([h, h0])
-        h = Activation(**edge_activation)(h)
-        if dropout is not None:
-            h = Dropout(**dropout)(h)
-
-    mv = AggregateLocalEdges(**pooling_args)([n, h, edi])
-    mv = Concatenate(axis=-1)([mv, n])
-    hv = Dense(**node_dense)(mv)
-
-    # Output embedding choice
-    n = hv
-    if output_embedding == 'graph':
-        out = PoolingNodes(**pooling_args)([count_nodes, n, batch_id_node])
-        if use_graph_state:
-            out = ks.layers.Concatenate()([graph_state, out])
-        out = MLP(**output_mlp)(out)
-    elif output_embedding == 'node':
-        if use_graph_state:
-            graph_state_node = GatherState()([graph_state, n])
-            n = Concatenate()([n, graph_state_node])
-        out = GraphMLP(**output_mlp)(n)
-    else:
-        raise ValueError("Unsupported embedding mode for `DMPNN`.")
-
-    return out
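
Following the documented input order, a hypothetical `inputs` specification for `make_model` could look as follows; the feature dimensions and dtypes are illustrative placeholders, not library defaults:

# Each dict is unpacked into keras.layers.Input; order must match the model definition.
inputs = [
    {"shape": (None, 41), "name": "node_attributes", "dtype": "float32"},
    {"shape": (None, 11), "name": "edge_attributes", "dtype": "float32"},
    {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"},
    {"shape": (None, 1), "name": "edge_pairs", "dtype": "int64"},
    {"shape": (), "name": "total_nodes", "dtype": "int64"},
    {"shape": (), "name": "total_edges", "dtype": "int64"},
]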

kgcnn/literature/DMPNN/_model.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+import keras_core as ks
+from keras_core.layers import Concatenate, Dense, Add, Activation, Dropout
+from kgcnn.layers.aggr import AggregateLocalEdges
+from kgcnn.layers.gather import GatherNodesOutgoing, GatherState
+from kgcnn.layers.mlp import MLP, GraphMLP
+from kgcnn.layers.modules import Embedding
+from kgcnn.layers.pooling import PoolingNodes
+from kgcnn.literature.DMPNN._layers import DMPNNPPoolingEdgesDirected
+
+
+def model_disjoint(inputs: list = None,
+                   use_node_embedding: bool = None,
+                   use_edge_embedding: bool = None,
+                   use_graph_embedding: bool = None,
+                   input_node_embedding: dict = None,
+                   input_edge_embedding: dict = None,
+                   input_graph_embedding: dict = None,
+                   pooling_args: dict = None,
+                   edge_initialize: dict = None,
+                   edge_dense: dict = None,
+                   edge_activation: dict = None,
+                   node_dense: dict = None,
+                   dropout: dict = None,
+                   depth: int = None,
+                   use_graph_state: bool = False,
+                   output_embedding: str = None,
+                   output_mlp: dict = None):
+    n, ed, edi, batch_id_node, ed_pairs, count_nodes, graph_state = inputs
+
+    # Embedding, if no feature dimension
+    if use_node_embedding:
+        n = Embedding(**input_node_embedding)(n)
+    if use_edge_embedding:
+        ed = Embedding(**input_edge_embedding)(ed)
+    if use_graph_state:
+        if use_graph_embedding:
+            graph_state = Embedding(**input_graph_embedding)(graph_state)
+
+    # Make first edge hidden h0
+    h_n0 = GatherNodesOutgoing()([n, edi])
+    h0 = Concatenate(axis=-1)([h_n0, ed])
+    h0 = Dense(**edge_initialize)(h0)
+
+    # One Dense layer for all message steps
+    edge_dense_all = Dense(**edge_dense)  # Should be linear activation
+
+    # Model Loop
+    h = h0
+    for i in range(depth):
+        m_vw = DMPNNPPoolingEdgesDirected()([n, h, edi, ed_pairs])
+        h = edge_dense_all(m_vw)
+        h = Add()([h, h0])
+        h = Activation(**edge_activation)(h)
+        if dropout is not None:
+            h = Dropout(**dropout)(h)
+
+    mv = AggregateLocalEdges(**pooling_args)([n, h, edi])
+    mv = Concatenate(axis=-1)([mv, n])
+    hv = Dense(**node_dense)(mv)
+
+    # Output embedding choice
+    n = hv
+    if output_embedding == 'graph':
+        out = PoolingNodes(**pooling_args)([count_nodes, n, batch_id_node])
+        if use_graph_state:
+            out = ks.layers.Concatenate()([graph_state, out])
+        out = MLP(**output_mlp)(out)
+    elif output_embedding == 'node':
+        if use_graph_state:
+            graph_state_node = GatherState()([graph_state, n])
+            n = Concatenate()([n, graph_state_node])
+        out = GraphMLP(**output_mlp)(n)
+    else:
+        raise ValueError("Unsupported embedding mode for `DMPNN`.")
+
+    return out
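
The refactor separates the batched-input casting in `_make.py` from the computation on disjoint tensors in `_model.py`: `model_disjoint` receives the nodes of all graphs in the batch concatenated into one "super graph", together with per-node batch ids and per-graph node counts. A small numpy illustration of that disjoint layout and of the segment-sum pooling that `PoolingNodes` performs with `[count_nodes, n, batch_id_node]` (a sketch under assumed shapes, not kgcnn code):

import numpy as np

# Two graphs with 2 and 3 nodes, stacked into one disjoint "super graph".
n = np.random.rand(5, 16).astype("float32")   # node features: 2 + 3 nodes
batch_id_node = np.array([0, 0, 1, 1, 1])     # graph id per node
count_nodes = np.array([2, 3])                # per-graph node counts

# PoolingNodes-style segment sum back to a (batch, F) tensor:
pooled = np.zeros((len(count_nodes), n.shape[1]), dtype=n.dtype)
np.add.at(pooled, batch_id_node, n)           # pooled[i] = sum of graph i's nodes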

kgcnn/literature/GAT/_make.py

Lines changed: 2 additions & 2 deletions
@@ -80,8 +80,8 @@ def make_model(inputs: list = None,
         - edge_attributes (Tensor): Edge attributes of shape `(batch, None, F)` or `(batch, None)`
             using an embedding layer.
         - edge_indices (Tensor): Index list for edges of shape `(batch, None, 2)`.
-        - total_nodes(Tensor, optional): Number of Nodes in graph if not same sized graphs of shape `(batch, )` .
-        - total_edges(Tensor, optional): Number of Edges in graph if not same sized graphs of shape `(batch, )` .
+        - total_nodes(Tensor): Number of Nodes in graph if not same sized graphs of shape `(batch, )` .
+        - total_edges(Tensor): Number of Edges in graph if not same sized graphs of shape `(batch, )` .
 
     Outputs:
         Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`.

kgcnn/literature/GAT/_model.py

Whitespace-only changes.

kgcnn/literature/GATv2/_model.py

Whitespace-only changes.

kgcnn/literature/GCN/_mode.py

Whitespace-only changes.

kgcnn/literature/GraphSAGE/_model.py

Whitespace-only changes.

kgcnn/losses/losses.py

Lines changed: 5 additions & 5 deletions
@@ -31,11 +31,11 @@ def __init__(self, reduction="sum_over_batch_size", name="force_mean_absolute_er
     def call(self, y_true, y_pred):
         # Shape: (batch, N, 3, S)
         if self.find_padded_atoms:
-            check_nonzero = ops.logical_not(
-                ops.all(ops.isclose(y_true, ops.convert_to_tensor(0., dtype=y_true.dtype)), axis=2))
-            row_count = ops.cast(ops.sum(check_nonzero, axis=1), dtype=y_true.dtype)
-            norm = 1/row_count
-            norm = ops.where(ops.isnan(norm), 1., norm)
+            check_nonzero = ops.cast(ops.logical_not(
+                ops.all(ops.isclose(y_true, ops.convert_to_tensor(0., dtype=y_true.dtype)), axis=2)), dtype="int32")
+            row_count = ops.sum(check_nonzero, axis=1)
+            row_count = ops.where(row_count < 1, ops.cast(ops.shape(y_true)[1], dtype=row_count.dtype), row_count)
+            norm = 1/ops.cast(row_count, dtype=y_true.dtype)
         else:
             norm = 1/ops.shape(y_true)[1]
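
The previous normalization guarded against NaN, but dividing by a zero row count in floating point yields inf, not NaN, so samples whose force targets are entirely padding presumably slipped through the old guard; the new code clamps the integer count before dividing. A small numpy check of both behaviours (illustrative, not the keras `ops` code):

import numpy as np

y_true = np.zeros((1, 4, 3, 1))                    # one sample whose 4 atoms are all padding
mask = ~np.all(np.isclose(y_true, 0.0), axis=2)    # (batch, N, S): True for real atoms
row_count = mask.sum(axis=1).astype("float64")     # 0.0 for the fully padded sample

with np.errstate(divide="ignore"):
    old_norm = 1.0 / row_count                     # inf, and np.isnan(inf) is False
print(np.where(np.isnan(old_norm), 1.0, old_norm)) # [[inf]]: the old guard misses it

count = mask.sum(axis=1)                           # integer counts, as in the new code
count = np.where(count < 1, y_true.shape[1], count)
print(1.0 / count)                                 # [[0.25]]: finite normalization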
