Skip to content

Commit cfdac31

Browse files
committed
Added rGIN (not tested)
1 parent a3fbe86 commit cfdac31

File tree

6 files changed

+294
-3
lines changed

6 files changed

+294
-3
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ List of contributors to kgcnn modules.
22

33
- GNNExplainer module by robinruff
44
- DGIN by thegodone
5+
- rGIN by thegodone
56

kgcnn/literature/GIN/_make.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@
4444
"output_to_tensor": None, # deprecated
4545
"output_tensor_type": "padded",
4646
"output_mlp": {"use_bias": True, "units": 1,
47-
"activation": "softmax"}
47+
"activation": "softmax"},
48+
"output_scaling": None,
4849
}
4950

5051

kgcnn/literature/INorp/_make.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
'output_embedding': 'graph',
5252
"output_to_tensor": None, # deprecated
5353
"output_tensor_type": "padded",
54+
"output_scaling": None,
5455
'output_mlp': {
5556
"use_bias": [True, True, False], "units": [25, 10, 1],
5657
"activation": ['relu', 'relu', 'sigmoid']}
@@ -130,8 +131,8 @@ def make_model(inputs: list = None,
130131
model_inputs,
131132
input_tensor_type=input_tensor_type,
132133
cast_disjoint_kwargs=cast_disjoint_kwargs,
133-
mask_assignment=[0, 1, 1],
134-
index_assignment=[None, None, 0]
134+
mask_assignment=[0, 1, 1, None],
135+
index_assignment=[None, None, 0, None]
135136
)
136137

137138
n, ed, disjoint_indices, gs, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = di_inputs

kgcnn/literature/rGIN/_layers.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import keras as ks
2+
from keras import ops
3+
from keras.layers import Layer
4+
from kgcnn.layers.gather import GatherNodesOutgoing
5+
from kgcnn.layers.aggr import AggregateLocalEdges
6+
from keras.layers import Add, Concatenate
7+
8+
9+
class rGIN(Layer):
10+
r"""Random Features Strengthen Graph Neural Networks <https://arxiv.org/abs/2002.03155>`__ .
11+
12+
Computes graph convolution at step :math:`k` for node embeddings :math:`h_\nu` as:
13+
14+
.. math::
15+
h_\nu^{(k)} = \phi^{(k)} ((1+\epsilon^{(k)}) h_\nu^{k-1} + \sum_{u\in N(\nu)}) h_u^{k-1}.
16+
17+
with optional learnable :math:`\epsilon^{(k)}`
18+
19+
.. note::
20+
21+
The non-linear mapping :math:`\phi^{(k)}`, usually an :obj:`MLP`, is not included in this layer.
22+
"""
23+
24+
def __init__(self,
25+
pooling_method='sum',
26+
epsilon_learnable=False,
27+
random_range=100,
28+
**kwargs):
29+
"""Initialize layer.
30+
31+
Args:
32+
epsilon_learnable (bool): If epsilon is learnable or just constant zero. Default is False.
33+
pooling_method (str): Pooling method for summing edges. Default is 'segment_sum'.
34+
"""
35+
super(rGIN, self).__init__(**kwargs)
36+
self.pooling_method = pooling_method
37+
self.epsilon_learnable = epsilon_learnable
38+
self.random_range = random_range
39+
40+
# Layers
41+
self.lay_gather = GatherNodesOutgoing()
42+
self.lay_pool = AggregateLocalEdges(pooling_method=self.pooling_method)
43+
self.lay_add = Add()
44+
self.lay_concat = Concatenate()
45+
46+
# Epsilon with trainable as optional and default zeros initialized.
47+
self.eps_k = self.add_weight(
48+
name="epsilon_k",
49+
shape=(),
50+
trainable=self.epsilon_learnable,
51+
initializer="zeros",
52+
dtype=self.dtype
53+
)
54+
55+
def build(self, input_shape):
56+
"""Build layer."""
57+
super(rGIN, self).build(input_shape)
58+
59+
def call(self, inputs, **kwargs):
60+
r"""Forward pass.
61+
62+
Args:
63+
inputs: [nodes, edge_index]
64+
65+
- nodes (Tensor): Node embeddings of shape `([N], F)`
66+
- edge_index (Tensor): Edge indices referring to nodes of shape `(2, [M])`
67+
68+
Returns:
69+
Tensor: Node embeddings of shape `([N], F)`
70+
"""
71+
node, edge_index = inputs
72+
node_shape = ops.shape(node)
73+
random_values = ops.cast(
74+
ks.random.uniform([node_shape[0], 1], maxval=self.random_range, dtype="int32"),
75+
self.dtype) / self.random_range,
76+
77+
node = self.lay_concat([node, random_values])
78+
79+
ed = self.lay_gather([node, edge_index], **kwargs)
80+
nu = self.lay_pool([node, ed, edge_index], **kwargs) # Summing for each node connection
81+
no = (1 + self.eps_k) * node
82+
83+
# Add the random features to the original features
84+
out = self.lay_add([node, nu], **kwargs)
85+
86+
return out
87+
88+
def get_config(self):
89+
"""Update config."""
90+
config = super(rGIN, self).get_config()
91+
config.update({"pooling_method": self.pooling_method,
92+
"epsilon_learnable": self.epsilon_learnable,
93+
"random_range": self.random_range})
94+
return config

kgcnn/literature/rGIN/_make.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# import keras_core as ks
2+
from ._model import model_disjoint
3+
from kgcnn.models.utils import update_model_kwargs
4+
from kgcnn.layers.scale import get as get_scaler
5+
from kgcnn.layers.modules import Input
6+
from kgcnn.models.casting import template_cast_output, template_cast_list_input
7+
from keras.backend import backend as backend_to_use
8+
from kgcnn.ops.activ import *
9+
10+
# Keep track of model version from commit date in literature.
11+
# To be updated if model is changed in a significant way.
12+
__model_version__ = "2023-10-12"
13+
14+
# Supported backends
15+
__kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"]
16+
if backend_to_use() not in __kgcnn_model_backend_supported__:
17+
raise NotImplementedError("Backend '%s' for model 'rGIN' is not supported." % backend_to_use())
18+
19+
# Implementation of rGIN in `keras` from paper:
20+
# Random Features Strengthen Graph Neural Networks
21+
# Ryoma Sato, Makoto Yamada, Hisashi Kashima
22+
# https://arxiv.org/abs/2002.03155
23+
24+
model_default = {
25+
"name": "rGIN",
26+
"inputs": [
27+
{"shape": (None,), "name": "node_attributes", "dtype": "float32", "ragged": True},
28+
{"shape": (None, 2), "name": "edge_indices", "dtype": "int64", "ragged": True}
29+
],
30+
"input_tensor_type": "padded",
31+
"cast_disjoint_kwargs": {},
32+
"input_embedding": None, # deprecated
33+
"input_node_embedding": {"input_dim": 95, "output_dim": 64},
34+
"gin_mlp": {"units": [64, 64], "use_bias": True, "activation": ["relu", "linear"],
35+
"use_normalization": True, "normalization_technique": "graph_batch"},
36+
"rgin_args": {"random_range": 100},
37+
"depth": 3, "dropout": 0.0, "verbose": 10,
38+
"last_mlp": {"use_bias": [True, True, True], "units": [64, 64, 64],
39+
"activation": ["relu", "relu", "linear"]},
40+
"output_embedding": 'graph',
41+
"output_mlp": {"use_bias": True, "units": 1,
42+
"activation": "softmax"},
43+
"output_to_tensor": None, # deprecated
44+
"output_tensor_type": "padded",
45+
"output_scaling": None,
46+
}
47+
48+
49+
@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"])
50+
def make_model(inputs: list = None,
51+
input_tensor_type: str = None,
52+
cast_disjoint_kwargs: dict = None,
53+
input_embedding: dict = None, # noqa
54+
input_node_embedding: dict = None,
55+
depth: int = None,
56+
rgin_args: dict = None,
57+
gin_mlp: dict = None,
58+
last_mlp: dict = None,
59+
dropout: float = None,
60+
name: str = None, # noqa
61+
verbose: int = None, # noqa
62+
output_embedding: str = None,
63+
output_to_tensor: bool = None, # noqa
64+
output_mlp: dict = None,
65+
output_scaling: dict = None,
66+
output_tensor_type: str = None,
67+
):
68+
r"""Make `rGIN <https://arxiv.org/abs/2002.03155>`__ graph network via functional API.
69+
Default parameters can be found in :obj:`kgcnn.literature.rGIN.model_default` .
70+
71+
72+
73+
Args:
74+
inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition.
75+
input_tensor_type (str): Input type of graph tensor. Default is "padded".
76+
cast_disjoint_kwargs (dict): Dictionary of arguments for castin layers.
77+
input_embedding (dict): Deprecated in favour of input_node_embedding etc.
78+
input_node_embedding (dict): Dictionary of arguments for nodes unpacked in :obj:`Embedding` layers.
79+
depth (int): Number of graph embedding units or depth of the network.
80+
rgin_args (dict): Dictionary of layer arguments unpacked in :obj:`GIN` convolutional layer.
81+
gin_mlp (dict): Dictionary of layer arguments unpacked in :obj:`MLP` for convolutional layer.
82+
last_mlp (dict): Dictionary of layer arguments unpacked in last :obj:`MLP` layer before output or pooling.
83+
dropout (float): Dropout to use.
84+
name (str): Name of the model.
85+
verbose (int): Level of print output.
86+
output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph".
87+
output_to_tensor (bool): Deprecated in favour of `output_tensor_type` .
88+
output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block.
89+
Defines number of model outputs and activation.
90+
output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None.
91+
output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded".
92+
93+
Returns:
94+
:obj:`keras.models.Model`
95+
"""
96+
# Make input
97+
# Make input
98+
model_inputs = [Input(**x) for x in inputs]
99+
100+
disjoint_inputs = template_cast_list_input(
101+
model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs,
102+
mask_assignment=[0, 1],
103+
index_assignment=[None, 0]
104+
)
105+
106+
n, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs
107+
108+
# Wrapping disjoint model.
109+
out = model_disjoint(
110+
[n, disjoint_indices, batch_id_node, count_nodes],
111+
use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False,
112+
input_node_embedding=input_node_embedding,
113+
gin_mlp=gin_mlp,
114+
depth=depth,
115+
rgin_args=rgin_args,
116+
last_mlp=last_mlp,
117+
output_mlp=output_mlp,
118+
output_embedding=output_embedding,
119+
dropout=dropout
120+
)
121+
122+
if output_scaling is not None:
123+
scaler = get_scaler(output_scaling["name"])(**output_scaling)
124+
out = scaler(out)
125+
126+
# Output embedding choice
127+
out = template_cast_output(
128+
[out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges],
129+
output_embedding=output_embedding, output_tensor_type=output_tensor_type,
130+
input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs
131+
)
132+
133+
model = ks.models.Model(inputs=model_inputs, outputs=out, name=name)
134+
model.__kgcnn_model_version__ = __model_version__
135+
136+
if output_scaling is not None:
137+
def set_scale(*args, **kwargs):
138+
scaler.set_scale(*args, **kwargs)
139+
140+
setattr(model, "set_scale", set_scale)
141+
return model
142+
143+
144+
make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__)

kgcnn/literature/rGIN/_model.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from kgcnn.layers.mlp import GraphMLP, MLP
2+
from kgcnn.layers.modules import Embedding
3+
from keras.layers import Dense, Dropout, Add
4+
from kgcnn.layers.pooling import PoolingNodes
5+
from ._layers import rGIN
6+
7+
8+
def model_disjoint(
9+
inputs,
10+
use_node_embedding,
11+
input_node_embedding,
12+
gin_mlp,
13+
depth,
14+
rgin_args,
15+
last_mlp,
16+
output_mlp,
17+
output_embedding,
18+
dropout
19+
):
20+
n, edi, batch_id_node, count_nodes = inputs
21+
22+
# Embedding, if no feature dimension
23+
if use_node_embedding:
24+
n = Embedding(**input_node_embedding)(n)
25+
26+
# Model
27+
# Map to the required number of units.
28+
n_units = gin_mlp["units"][-1] if isinstance(gin_mlp["units"], list) else int(gin_mlp["units"])
29+
n = Dense(n_units, use_bias=True, activation='linear')(n)
30+
list_embeddings = [n]
31+
for i in range(0, depth):
32+
n = rGIN(**rgin_args)([n, edi])
33+
n = GraphMLP(**gin_mlp)([n, batch_id_node, count_nodes])
34+
list_embeddings.append(n)
35+
36+
# Output embedding choice
37+
if output_embedding == "graph":
38+
out = [PoolingNodes()([count_nodes, x, batch_id_node]) for x in list_embeddings] # will return tensor
39+
out = [MLP(**last_mlp)(x) for x in out]
40+
out = [Dropout(dropout)(x) for x in out]
41+
out = Add()(out)
42+
out = MLP(**output_mlp)(out)
43+
elif output_embedding == "node": # Node labeling
44+
out = n
45+
out = GraphMLP(**last_mlp)([out, batch_id_node, count_nodes])
46+
out = GraphMLP(**output_mlp)([out, batch_id_node, count_nodes])
47+
else:
48+
raise ValueError("Unsupported output embedding for mode `rGIN`")
49+
50+
return out

0 commit comments

Comments
 (0)