diff --git a/AUTHORS b/AUTHORS
index 43b03f5f..f7e21c8d 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -2,4 +2,5 @@
 List of contributors to kgcnn modules.
 
 - GNNExplainer module by robinruff
 - DGIN by thegodone
+- rGIN by thegodone
diff --git a/kgcnn/literature/GIN/_make.py b/kgcnn/literature/GIN/_make.py
index 3b4bac1d..cfafeded 100644
--- a/kgcnn/literature/GIN/_make.py
+++ b/kgcnn/literature/GIN/_make.py
@@ -44,7 +44,8 @@
     "output_to_tensor": None,  # deprecated
     "output_tensor_type": "padded",
     "output_mlp": {"use_bias": True, "units": 1,
-                   "activation": "softmax"}
+                   "activation": "softmax"},
+    "output_scaling": None,
 }
diff --git a/kgcnn/literature/INorp/_make.py b/kgcnn/literature/INorp/_make.py
index 7213646e..54c34242 100644
--- a/kgcnn/literature/INorp/_make.py
+++ b/kgcnn/literature/INorp/_make.py
@@ -51,6 +51,7 @@
     'output_embedding': 'graph',
     "output_to_tensor": None,  # deprecated
     "output_tensor_type": "padded",
+    "output_scaling": None,
     'output_mlp': {"use_bias": [True, True, False], "units": [25, 10, 1],
                    "activation": ['relu', 'relu', 'sigmoid']}
@@ -130,8 +131,8 @@ def make_model(inputs: list = None,
         model_inputs,
         input_tensor_type=input_tensor_type,
         cast_disjoint_kwargs=cast_disjoint_kwargs,
-        mask_assignment=[0, 1, 1],
-        index_assignment=[None, None, 0]
+        mask_assignment=[0, 1, 1, None],
+        index_assignment=[None, None, 0, None]
     )
     n, ed, disjoint_indices, gs, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = di_inputs
diff --git a/kgcnn/literature/rGIN/_layers.py b/kgcnn/literature/rGIN/_layers.py
new file mode 100644
index 00000000..f27ea686
--- /dev/null
+++ b/kgcnn/literature/rGIN/_layers.py
@@ -0,0 +1,94 @@
+import keras as ks
+from keras import ops
+from keras.layers import Layer, Add, Concatenate
+from kgcnn.layers.gather import GatherNodesOutgoing
+from kgcnn.layers.aggr import AggregateLocalEdges
+
+
+class rGIN(Layer):
+    r"""Layer from `Random Features Strengthen Graph Neural Networks <https://arxiv.org/abs/2002.03155>`__ .
+
+    Computes graph convolution at step :math:`k` for node embeddings :math:`h_\nu` as:
+
+    .. math::
+
+        h_\nu^{(k)} = \phi^{(k)} \left( (1 + \epsilon^{(k)}) \, h_\nu^{(k-1)} + \sum_{u \in N(\nu)} h_u^{(k-1)} \right) ,
+
+    with optional learnable :math:`\epsilon^{(k)}` . In addition to plain GIN, a random scalar feature drawn
+    uniformly from :math:`\{0, 1/r, \dots, (r-1)/r\}` , with :math:`r` being :obj:`random_range` , is
+    concatenated to each node embedding before aggregation.
+
+    .. note::
+
+        The non-linear mapping :math:`\phi^{(k)}` , usually an :obj:`MLP` , is not included in this layer.
+    """
+
+    def __init__(self,
+                 pooling_method='sum',
+                 epsilon_learnable=False,
+                 random_range=100,
+                 **kwargs):
+        """Initialize layer.
+
+        Args:
+            pooling_method (str): Pooling method for summing edges. Default is 'sum'.
+            epsilon_learnable (bool): If epsilon is learnable or just constant zero. Default is False.
+            random_range (int): Number of discrete random values for the extra node feature. Default is 100.
+        """
+        super(rGIN, self).__init__(**kwargs)
+        self.pooling_method = pooling_method
+        self.epsilon_learnable = epsilon_learnable
+        self.random_range = random_range
+
+        # Layers
+        self.lay_gather = GatherNodesOutgoing()
+        self.lay_pool = AggregateLocalEdges(pooling_method=self.pooling_method)
+        self.lay_add = Add()
+        self.lay_concat = Concatenate()
+
+        # Epsilon with trainable as optional and default zeros initialized.
+        self.eps_k = self.add_weight(
+            name="epsilon_k",
+            shape=(),
+            trainable=self.epsilon_learnable,
+            initializer="zeros",
+            dtype=self.dtype
+        )
+
+    def build(self, input_shape):
+        """Build layer."""
+        super(rGIN, self).build(input_shape)
+
+    def call(self, inputs, **kwargs):
+        r"""Forward pass.
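As a quick sketch of the layer contract above (toy disjoint tensors matching the shapes documented in `call`; the values are illustrative only, not part of the patch):

import numpy as np
from keras import ops
from kgcnn.literature.rGIN._layers import rGIN

# Five nodes with three features in flat disjoint layout.
nodes = ops.convert_to_tensor(np.random.rand(5, 3).astype("float32"))
# Edge indices of shape (2, M) referring to node positions.
edge_index = ops.convert_to_tensor(np.array([[0, 1, 2, 3, 4], [1, 0, 3, 4, 2]], dtype="int64"))

out = rGIN(pooling_method="sum", random_range=100)([nodes, edge_index])
print(ops.shape(out))  # (5, 4): one random feature column is concatenated before aggregation.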
+
+        Args:
+            inputs: [nodes, edge_index]
+
+                - nodes (Tensor): Node embeddings of shape `([N], F)`
+                - edge_index (Tensor): Edge indices referring to nodes of shape `(2, [M])`
+
+        Returns:
+            Tensor: Node embeddings of shape `([N], F+1)` , including the random feature column.
+        """
+        node, edge_index = inputs
+        node_shape = ops.shape(node)
+        # Draw one random integer per node and rescale it to [0, 1).
+        random_values = ops.cast(
+            ks.random.randint([node_shape[0], 1], minval=0, maxval=self.random_range, dtype="int32"),
+            self.dtype) / self.random_range
+
+        # Concatenate the random feature to the original node features.
+        node = self.lay_concat([node, random_values])
+
+        ed = self.lay_gather([node, edge_index], **kwargs)
+        nu = self.lay_pool([node, ed, edge_index], **kwargs)  # Summing for each node connection.
+        no = (1 + self.eps_k) * node
+
+        # Combine the epsilon-scaled self-embedding with the aggregated neighborhood.
+        out = self.lay_add([no, nu], **kwargs)
+
+        return out
+
+    def get_config(self):
+        """Update config."""
+        config = super(rGIN, self).get_config()
+        config.update({"pooling_method": self.pooling_method,
+                       "epsilon_learnable": self.epsilon_learnable,
+                       "random_range": self.random_range})
+        return config
diff --git a/kgcnn/literature/rGIN/_make.py b/kgcnn/literature/rGIN/_make.py
new file mode 100644
index 00000000..2bbc62a3
--- /dev/null
+++ b/kgcnn/literature/rGIN/_make.py
@@ -0,0 +1,144 @@
+import keras as ks
+from ._model import model_disjoint
+from kgcnn.models.utils import update_model_kwargs
+from kgcnn.layers.scale import get as get_scaler
+from kgcnn.layers.modules import Input
+from kgcnn.models.casting import template_cast_output, template_cast_list_input
+from keras.backend import backend as backend_to_use
+from kgcnn.ops.activ import *  # registers custom kgcnn activation functions
+
+# Keep track of model version from commit date in literature.
+# To be updated if model is changed in a significant way.
+__model_version__ = "2023-10-12"
+
+# Supported backends.
+__kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"]
+if backend_to_use() not in __kgcnn_model_backend_supported__:
+    raise NotImplementedError("Backend '%s' for model 'rGIN' is not supported." % backend_to_use())
+
+# Implementation of rGIN in `keras` from paper:
+# Random Features Strengthen Graph Neural Networks
+# Ryoma Sato, Makoto Yamada, Hisashi Kashima
+# https://arxiv.org/abs/2002.03155
+
+model_default = {
+    "name": "rGIN",
+    "inputs": [
+        {"shape": (None,), "name": "node_attributes", "dtype": "float32", "ragged": True},
+        {"shape": (None, 2), "name": "edge_indices", "dtype": "int64", "ragged": True}
+    ],
+    "input_tensor_type": "padded",
+    "cast_disjoint_kwargs": {},
+    "input_embedding": None,  # deprecated
+    "input_node_embedding": {"input_dim": 95, "output_dim": 64},
+    "gin_mlp": {"units": [64, 64], "use_bias": True, "activation": ["relu", "linear"],
+                "use_normalization": True, "normalization_technique": "graph_batch"},
+    "rgin_args": {"random_range": 100},
+    "depth": 3, "dropout": 0.0, "verbose": 10,
+    "last_mlp": {"use_bias": [True, True, True], "units": [64, 64, 64],
+                 "activation": ["relu", "relu", "linear"]},
+    "output_embedding": 'graph',
+    "output_mlp": {"use_bias": True, "units": 1,
+                   "activation": "softmax"},
+    "output_to_tensor": None,  # deprecated
+    "output_tensor_type": "padded",
+    "output_scaling": None,
+}
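The decorator on `make_model` below fills every argument left at `None` from `model_default` (with `update_recursive=0`, nested dicts supplied by the caller are intended to replace the defaults rather than being merged). A short sketch of overriding only selected keys; it assumes `make_model` is re-exported by the package `__init__`, which is not part of this diff:

from kgcnn.literature.rGIN import make_model

# Only the keys passed here change; everything else is taken from model_default.
model = make_model(
    rgin_args={"random_range": 1000},
    depth=4,
    output_mlp={"use_bias": True, "units": 1, "activation": "linear"},
)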
+
+
+@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"])
+def make_model(inputs: list = None,
+               input_tensor_type: str = None,
+               cast_disjoint_kwargs: dict = None,
+               input_embedding: dict = None,  # noqa
+               input_node_embedding: dict = None,
+               depth: int = None,
+               rgin_args: dict = None,
+               gin_mlp: dict = None,
+               last_mlp: dict = None,
+               dropout: float = None,
+               name: str = None,  # noqa
+               verbose: int = None,  # noqa
+               output_embedding: str = None,
+               output_to_tensor: bool = None,  # noqa
+               output_mlp: dict = None,
+               output_scaling: dict = None,
+               output_tensor_type: str = None
+               ):
+    r"""Make `rGIN <https://arxiv.org/abs/2002.03155>`__ graph network via functional API.
+    Default parameters can be found in :obj:`kgcnn.literature.rGIN.model_default` .
+
+    **Model inputs**:
+    %s
+
+    **Model outputs**:
+    %s
+
+    Args:
+        inputs (list): List of dictionaries unpacked in :obj:`keras.layers.Input` . Order must match model definition.
+        input_tensor_type (str): Input type of graph tensor. Default is "padded".
+        cast_disjoint_kwargs (dict): Dictionary of arguments for casting layers.
+        input_embedding (dict): Deprecated in favour of input_node_embedding etc.
+        input_node_embedding (dict): Dictionary of arguments for nodes unpacked in :obj:`Embedding` layers.
+        depth (int): Number of graph embedding units or depth of the network.
+        rgin_args (dict): Dictionary of layer arguments unpacked in :obj:`rGIN` convolutional layer.
+        gin_mlp (dict): Dictionary of layer arguments unpacked in :obj:`MLP` for convolutional layer.
+        last_mlp (dict): Dictionary of layer arguments unpacked in last :obj:`MLP` layer before output or pooling.
+        dropout (float): Dropout to use.
+        name (str): Name of the model.
+        verbose (int): Level of print output.
+        output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph".
+        output_to_tensor (bool): Deprecated in favour of `output_tensor_type` .
+        output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block.
+            Defines number of model outputs and activation.
+        output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None.
+        output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded".
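When `output_scaling` is set, a scale layer wraps the output and the returned model exposes the `set_scale` hook defined in the function body below. A sketch of the hook, where the scaler name "StandardLabelScaler" is an assumption carried over from other kgcnn models, not fixed by this diff:

# Hypothetical: append standard label scaling to the model output.
model = make_model(output_scaling={"name": "StandardLabelScaler"})
# model.set_scale(...) forwards its arguments to the scale layer,
# typically to transfer the state of a scaler fitted on training labels.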
+
+    Returns:
+        :obj:`keras.models.Model`
+    """
+    # Make input layers from the given specification.
+    model_inputs = [Input(**x) for x in inputs]
+
+    # Cast to disjoint graph representation.
+    disjoint_inputs = template_cast_list_input(
+        model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs,
+        mask_assignment=[0, 1],
+        index_assignment=[None, 0]
+    )
+
+    n, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs
+
+    # Wrapping disjoint model.
+    out = model_disjoint(
+        [n, disjoint_indices, batch_id_node, count_nodes],
+        use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False,
+        input_node_embedding=input_node_embedding,
+        gin_mlp=gin_mlp,
+        depth=depth,
+        rgin_args=rgin_args,
+        last_mlp=last_mlp,
+        output_mlp=output_mlp,
+        output_embedding=output_embedding,
+        dropout=dropout
+    )
+
+    if output_scaling is not None:
+        scaler = get_scaler(output_scaling["name"])(**output_scaling)
+        out = scaler(out)
+
+    # Cast the disjoint output back to the requested output tensor type.
+    out = template_cast_output(
+        [out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges],
+        output_embedding=output_embedding, output_tensor_type=output_tensor_type,
+        input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs
+    )
+
+    model = ks.models.Model(inputs=model_inputs, outputs=out, name=name)
+    model.__kgcnn_model_version__ = __model_version__
+
+    if output_scaling is not None:
+        def set_scale(*args, **kwargs):
+            scaler.set_scale(*args, **kwargs)
+
+        setattr(model, "set_scale", set_scale)
+    return model
+
+
+make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__)
diff --git a/kgcnn/literature/rGIN/_model.py b/kgcnn/literature/rGIN/_model.py
new file mode 100644
index 00000000..0a0aab7e
--- /dev/null
+++ b/kgcnn/literature/rGIN/_model.py
@@ -0,0 +1,50 @@
+from kgcnn.layers.mlp import GraphMLP, MLP
+from kgcnn.layers.modules import Embedding
+from keras.layers import Dense, Dropout, Add
+from kgcnn.layers.pooling import PoolingNodes
+from ._layers import rGIN
+
+
+def model_disjoint(
+        inputs,
+        use_node_embedding,
+        input_node_embedding,
+        gin_mlp,
+        depth,
+        rgin_args,
+        last_mlp,
+        output_mlp,
+        output_embedding,
+        dropout
+):
+    n, edi, batch_id_node, count_nodes = inputs
+
+    # Embedding, if no feature dimension.
+    if use_node_embedding:
+        n = Embedding(**input_node_embedding)(n)
+
+    # Model
+    # Map to the required number of units.
+    n_units = gin_mlp["units"][-1] if isinstance(gin_mlp["units"], list) else int(gin_mlp["units"])
+    n = Dense(n_units, use_bias=True, activation='linear')(n)
+    list_embeddings = [n]
+    for i in range(0, depth):
+        n = rGIN(**rgin_args)([n, edi])
+        n = GraphMLP(**gin_mlp)([n, batch_id_node, count_nodes])
+        list_embeddings.append(n)
+
+    # Output embedding choice
+    if output_embedding == "graph":
+        # Pool the node states of every depth, map each through the last MLP and sum the read-outs.
+        out = [PoolingNodes()([count_nodes, x, batch_id_node]) for x in list_embeddings]
+        out = [MLP(**last_mlp)(x) for x in out]
+        out = [Dropout(dropout)(x) for x in out]
+        out = Add()(out)
+        out = MLP(**output_mlp)(out)
+    elif output_embedding == "node":  # Node labeling
+        out = n
+        out = GraphMLP(**last_mlp)([out, batch_id_node, count_nodes])
+        out = GraphMLP(**output_mlp)([out, batch_id_node, count_nodes])
+    else:
+        raise ValueError("Unsupported output embedding for model `rGIN`.")
+
+    return out
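For the graph readout in `model_disjoint`, `PoolingNodes` sums node states per graph via the disjoint batch assignment; its default behaviour corresponds in spirit to a segment sum, as in this standalone sketch with toy values:

import numpy as np
from keras import ops

nodes = ops.convert_to_tensor(np.arange(10, dtype="float32").reshape(5, 2))
batch_id_node = ops.convert_to_tensor(np.array([0, 0, 1, 1, 1], dtype="int32"))
# One row per graph: nodes 0-1 belong to graph 0, nodes 2-4 to graph 1.
pooled = ops.segment_sum(nodes, batch_id_node, num_segments=2)
print(ops.shape(pooled))  # (2, 2)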