diff --git a/AUTHORS b/AUTHORS index f7e21c8d..3f5fb9b4 100644 --- a/AUTHORS +++ b/AUTHORS @@ -3,4 +3,5 @@ List of contributors to kgcnn modules. - GNNExplainer module by robinruff - DGIN by thegodone - rGIN by thegodone - +- MEGAN by the16thpythonist +- MoGAT by thegodone \ No newline at end of file diff --git a/README.md b/README.md index 63ef58d7..9a0cee15 100644 --- a/README.md +++ b/README.md @@ -185,8 +185,6 @@ original implementations (with proper licencing). * **[rGIN](kgcnn/literature/rGIN)** [Random Features Strengthen Graph Neural Networks](https://arxiv.org/abs/2002.03155) by Sato et al. (2020) * **[Schnet](kgcnn/literature/Schnet)**: [SchNet – A deep learning architecture for molecules and materials ](https://aip.scitation.org/doi/10.1063/1.5019779) by Schütt et al. (2017) - -To be completed ... diff --git a/kgcnn/layers/attention.py b/kgcnn/layers/attention.py index d4f0aa64..e8522b91 100644 --- a/kgcnn/layers/attention.py +++ b/kgcnn/layers/attention.py @@ -198,9 +198,9 @@ def call(self, inputs, **kwargs): Args: inputs (list): of [node, edges, edge_indices] - - nodes (Tensor): Node embeddings of shape (batch, [N], F) - - edges (Tensor): Edge or message embeddings of shape (batch, [M], F) - - edge_indices (Tensor): Edge indices referring to nodes of shape (batch, [M], 2) + - nodes (Tensor): Node embeddings of shape ([N], F) + - edges (Tensor): Edge or message embeddings of shape ([M], F) + - edge_indices (Tensor): Edge indices referring to nodes of shape (2, [M]) Returns: Tensor: Embedding tensor of pooled edge attentions for each node. diff --git a/kgcnn/literature/DimeNetPP/_make.py b/kgcnn/literature/DimeNetPP/_make.py index f268cffa..29e9377c 100644 --- a/kgcnn/literature/DimeNetPP/_make.py +++ b/kgcnn/literature/DimeNetPP/_make.py @@ -87,7 +87,7 @@ def make_model(inputs: list = None, """Make `DimeNetPP `_ graph network via functional API. Default parameters can be found in :obj:`kgcnn.literature.DimeNetPP.model_default`. - Model inputs: + **Model inputs**: Model uses the list template of inputs and standard output template. The supported inputs are :obj:`[nodes, coordinates, edge_indices, angle_indices...]` with '...' indicating mask or ID tensors following the template below. @@ -95,7 +95,7 @@ def make_model(inputs: list = None, %s - Model outputs: + **Model outputs**: The standard output template: %s diff --git a/kgcnn/literature/INorp/_make.py b/kgcnn/literature/INorp/_make.py index 54c34242..3539991c 100644 --- a/kgcnn/literature/INorp/_make.py +++ b/kgcnn/literature/INorp/_make.py @@ -97,7 +97,7 @@ def make_model(inputs: list = None, %s Args: - inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition. + inputs (list): List of dictionaries unpacked in :obj:`keras.layers.Input`. Order must match model definition. input_tensor_type (str): Input type of graph tensor. Default is "padded". cast_disjoint_kwargs (dict): Dictionary of arguments for casting layer. input_embedding (dict): Deprecated in favour of input_node_embedding etc. diff --git a/kgcnn/literature/MAT/_layers.py b/kgcnn/literature/MAT/_layers.py new file mode 100644 index 00000000..eeb758ad --- /dev/null +++ b/kgcnn/literature/MAT/_layers.py @@ -0,0 +1,224 @@ +import keras as ks +from keras import ops +from typing import Union + + +class MATGlobalPool(ks.layers.Layer): + + def __init__(self, pooling_method: str = "sum", **kwargs): + super(MATGlobalPool, self).__init__(**kwargs) + self.pooling_method = pooling_method + + # Add mean with mask if required. 
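+        # Sketch of how a masked mean could look (not enabled; "mean" would also need
+        # to be allowed in the check below and handled as a branch in call()):
+        #   masked = inputs * ops.cast(mask, inputs.dtype)
+        #   return ops.sum(masked, axis=1) / ops.maximum(ops.sum(mask, axis=1), 1)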
+ if self.pooling_method not in ["sum"]: + raise ValueError("`pooling_method` must be in ['sum']") + + def build(self, input_shape): + super(MATGlobalPool, self).build(input_shape) + + def call(self, inputs, mask=None, **kwargs): + r"""Forward pass. + + Args: + inputs (Tensor): Node or edge features of shape `(batch, N, ...)` . + mask (Tensor): Not used. + + Returns: + Tensor: Pooled features, e.g. summed over first axis. + """ + if self.pooling_method == "sum": + return ops.sum(inputs, axis=1) + + def get_config(self): + config = super(MATGlobalPool, self).get_config() + config.update({"pooling_method": self.pooling_method}) + return config + + +class MATDistanceMatrix(ks.layers.Layer): + + def __init__(self, trafo: Union[str, None] = "exp", **kwargs): + super(MATDistanceMatrix, self).__init__(**kwargs) + self.trafo = trafo + if self.trafo not in [None, "exp", "softmax"]: + raise ValueError("`trafo` must be in [None, 'exp', 'softmax']") + + def build(self, input_shape): + super(MATDistanceMatrix, self).build(input_shape) + + def call(self, inputs, mask=None, **kwargs): + r"""Forward pass + + Args: + inputs (Tensor): Padded Coordinates of shape `(batch, N, 3)` . + mask (Tensor): Mask of coordinates of similar shape. + + Returns: + tuple: Distance matrix of shape `(batch, N, N, 1)` plus mask. + """ + # Shape of inputs (batch, N, 3) + # Shape of mask (batch, N, 3) + diff = ops.expand_dims(inputs, axis=1) - ops.expand_dims(inputs, axis=2) + dist = ops.sum(ops.square(diff), axis=-1, keepdims=True) + # shape of dist (batch, N, N, 1) + diff_mask = ops.expand_dims(mask, axis=1) * ops.expand_dims(mask, axis=2) + dist_mask = ops.prod(diff_mask, axis=-1, keepdims=True) + + if self.trafo == "exp": + dist += ops.where( + ops.cast(dist_mask, dtype="bool"), ops.zeros_like(dist), ops.ones_like(dist) / ks.backend.epsilon()) + dist = ops.exp(-dist) + elif self.trafo == "softmax": + dist += ops.where( + ops.cast(dist_mask, dtype="bool"), ops.zeros_like(dist), -ops.ones_like(dist) / ks.backend.epsilon()) + dist = ops.nn.softmax(dist, axis=2) + + dist = dist * dist_mask + return dist, dist_mask + + def get_config(self): + config = super(MATDistanceMatrix, self).get_config() + config.update({"trafo": self.trafo}) + return config + + +class MATReduceMask(ks.layers.Layer): + + def __init__(self, axis: int, keepdims: bool, **kwargs): + super(MATReduceMask, self).__init__(**kwargs) + self.axis = axis + self.keepdims = keepdims + + def build(self, input_shape): + super(MATReduceMask, self).build(input_shape) + + def call(self, inputs, **kwargs): + r"""Forward Pass. + + Args: + inputs (Tensor): Any (mask) Tensor of sufficient rank to reduce for given axis. + + Returns: + Tensor: Product of inputs along axis. + """ + return ops.prod(inputs, keepdims=self.keepdims, axis=self.axis) + + def get_config(self): + config = super(MATReduceMask, self).get_config() + config.update({"axis": self.axis, "keepdims": self.keepdims}) + return config + + +class MATExpandMask(ks.layers.Layer): + + def __init__(self, axis: int, **kwargs): + super(MATExpandMask, self).__init__(**kwargs) + self.axis = axis + + def build(self, input_shape): + super(MATExpandMask, self).build(input_shape) + + def call(self, inputs, **kwargs): + r"""Forward Pass. + + Args: + inputs (Tensor): Any (mask) Tensor to expand given axis. + + Returns: + Tensor: Input tensor with expanded axis. 
+ """ + return ops.expand_dims(inputs, axis=self.axis) + + def get_config(self): + config = super(MATExpandMask, self).get_config() + config.update({"axis": self.axis}) + return config + + +class MATAttentionHead(ks.layers.Layer): + + def __init__(self, units: int = 64, + lambda_distance: float = 0.3, lambda_attention: float = 0.3, + lambda_adjacency: Union[float, None] = None, add_identity: bool = False, + dropout: Union[float, None] = None, + **kwargs): + super(MATAttentionHead, self).__init__(**kwargs) + self.units = int(units) + self.add_identity = bool(add_identity) + self.lambda_distance = lambda_distance + self.lambda_attention = lambda_attention + if lambda_adjacency is not None: + self.lambda_adjacency = lambda_adjacency + else: + self.lambda_adjacency = 1.0 - self.lambda_attention - self.lambda_distance + self.scale = self.units ** -0.5 + self.dense_q = ks.layers.Dense(units=units) + self.dense_k = ks.layers.Dense(units=units) + self.dense_v = ks.layers.Dense(units=units) + self._dropout = dropout + if self._dropout is not None: + self.layer_dropout = ks.layers.Dropout(self._dropout) + + def build(self, input_shape): + super(MATAttentionHead, self).build(input_shape) + + def call(self, inputs, mask=None, **kwargs): + r"""Forward pass. + + Args: + inputs (list): List of [h_n, A_d, A_g] represented by padded :obj:`tf.Tensor` . + These are node features and adjacency matrix from distances and bonds or bond order. + mask (list): Mask tensors matching inputs, i.e. a mask tensor for each padded input. + + Returns: + Tensor: Padded node features of :math:`h_n` . + """ + h, a_d, a_g = inputs + h_mask, a_d_mask, a_g_mask = mask + q = ops.expand_dims(self.dense_q(h), axis=2) + k = ops.expand_dims(self.dense_k(h), axis=1) + v = self.dense_v(h) * h_mask + qk = q * k / self.scale + # Apply mask on self-attention + qk_mask = ops.expand_dims(h_mask, axis=1) * ops.expand_dims(h_mask, axis=2) # (b, 1, n, ...) * (b, n, 1, ...) + qk += ops.where(ops.cast(qk_mask, dtype="bool"), ops.zeros_like(qk), -ops.ones_like(qk) / ks.backend.epsilon()) + qk = ops.nn.softmax(qk, axis=2) + qk *= qk_mask + # Add diagonal to graph adjacency (optional). + if self.add_identity: + a_g_eye = ops.eye(ops.shape(a_g)[1], dtype=a_g.dtype) + a_g_eye = ops.repeat(ops.expand_dims(a_g_eye, axis=0), ops.shape(a_g)[:1], axis=0) + if a_g.shape.rank > 3: + a_g_eye = ops.expand_dims(a_g_eye, axis=-1) + a_g += a_g_eye + # Weights + qk = self.lambda_attention * qk + a_d = self.lambda_distance * ops.cast(a_d, dtype=h.dtype) + a_g = self.lambda_adjacency * ops.cast(a_g, dtype=h.dtype) + # print(qk.shape, a_d.shape, a_g.shape) + att = qk + a_d + a_g + # v has shape (b, N, F) + # att has shape (b, N, N, F) + if self._dropout is not None: + att = self.layer_dropout(att) + + # Or permute feature dimension to batch and apply on last axis via and permute back again + v = ops.transpose(v, axes=[0, 2, 1]) + att = ops.transpose(att, axes=[0, 3, 1, 2]) + hp = ops.einsum('...ij,...jk->...ik', att, ops.expand_dims(v, axis=3)) # From example in tf docs + hp = ops.squeeze(hp, axis=3) + hp = ops.transpose(hp, axes=[0, 2, 1]) + + # Same as above but may be slower. 
+ # hp = tf.einsum('bij...,bjk...->bik...', att, tf.expand_dims(v, axis=2)) + # hp = tf.squeeze(hp, axis=2) + + hp *= h_mask + return hp + + def get_config(self): + config = super(MATAttentionHead, self).get_config() + config.update({"units": self.units, "lambda_adjacency": self.lambda_adjacency, + "lambda_attention": self.lambda_attention, "lambda_distance": self.lambda_distance, + "dropout": self._dropout, "add_identity": self.add_identity}) + return config diff --git a/kgcnn/literature/MAT/_make.py b/kgcnn/literature/MAT/_make.py new file mode 100644 index 00000000..002157d6 --- /dev/null +++ b/kgcnn/literature/MAT/_make.py @@ -0,0 +1,210 @@ +import keras as ks +from keras.backend import backend as backend_to_use +from kgcnn.layers.modules import Embedding +from kgcnn.layers.mlp import MLP +from kgcnn.models.utils import update_model_kwargs +from ._layers import MATAttentionHead, MATDistanceMatrix, MATReduceMask, MATGlobalPool, MATExpandMask + +# Keep track of model version from commit date in literature. +# To be updated if model is changed in a significant way. +__model_version__ = "2023-12-08" + +# Supported backends +__kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"] +if backend_to_use() not in __kgcnn_model_backend_supported__: + raise NotImplementedError("Backend '%s' for model 'MAT' is not supported." % backend_to_use()) + +# Implementation of MAT in `tf.keras` from paper: +# Molecule Attention Transformer +# Łukasz Maziarka, Tomasz Danel, Sławomir Mucha, Krzysztof Rataj, Jacek Tabor, Stanisław Jastrzębski +# https://arxiv.org/abs/2002.08264 +# https://github.com/ardigen/MAT +# https://github.com/lucidrains/molecule-attention-transformer + + +model_default = { + "name": "MAT", + "inputs": [ + {"shape": (None,), "name": "node_number", "dtype": "int64"}, + + ], + "input_embedding": {"node": {"input_dim": 95, "output_dim": 64}, + "edge": {"input_dim": 95, "output_dim": 64}}, + "use_edge_embedding": False, + "max_atoms": None, + "distance_matrix_kwargs": {"trafo": "exp"}, + "attention_kwargs": {"units": 8, "lambda_attention": 0.3, "lambda_distance": 0.3, "lambda_adjacency": None, + "dropout": 0.1, "add_identity": False}, + "feed_forward_kwargs": {"units": [32, 32, 32], "activation": ["relu", "relu", "linear"]}, + "embedding_units": 32, + "depth": 5, + "heads": 8, + "merge_heads": "concat", + "verbose": 10, + "pooling_kwargs": {"pooling_method": "sum"}, + "output_embedding": "graph", + "output_to_tensor": True, + "output_mlp": {"use_bias": [True, True, True], "units": [32, 16, 1], + "activation": ["relu", "relu", "linear"]} +} + + +@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"]) +def make_model(name: str = None, + inputs: list = None, + input_node_embedding: dict = None, + input_tensor_type: str = None, + input_edge_embedding: dict = None, + distance_matrix_kwargs: dict = None, + attention_kwargs: dict = None, + feed_forward_kwargs:dict = None, + embedding_units: int = None, + depth: int = None, + heads: int = None, + merge_heads: str = None, + verbose: int = None, + pooling_kwargs: dict = None, + output_embedding: str = None, + output_to_tensor: bool = None, + output_mlp: dict = None + ): + r"""Make `MAT `__ graph network via functional API. + Default parameters can be found in :obj:`kgcnn.literature.MAT.model_default` . + + .. note:: + + We added a linear layer to keep correct node embedding dimension. 
+ + Inputs: + list: `[node_attributes, node_coordinates, adjacency_matrix, node_mask, adjacency_mask]` + + - node_attributes (Tensor): Node attributes of shape `(batch, N, F)` or `(batch, N)` + using an embedding layer. + - node_coordinates (Tensor): Node (atomic) coordinates of shape `(batch, N, 3)`. + - adjacency_matrix (Tensor): Edge attributes of shape `(batch, N, N, F)` or `(batch, N, N)` + using an embedding layer. + - node_mask (Tensor): Node mask of shape `(batch, N)` + - adjacency_mask (Tensor): Adjacency mask of shape `(batch, N)` + + Outputs: + Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`. + + Args: + name (str): Name of the model. Should be "MAT". + inputs (list): List of dictionaries unpacked in :obj:`keras.layers.Input`. Order must match model definition. + input_tensor_type (str): Input tensor type. Only "padded" is valid for this implementation. + input_node_embedding (dict): Dictionary of embedding arguments unpacked in :obj:`Embedding` layers. + input_edge_embedding (dict): Dictionary of embedding arguments unpacked in :obj:`Embedding` layers. + depth (int): Number of graph embedding units or depth of the network. + verbose (int): Level for print information. + distance_matrix_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`MATDistanceMatrix`. + attention_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`MATDistanceMatrix`. + feed_forward_kwargs (dict): Dictionary of layer arguments unpacked in feed forward :obj:`MLP`. + embedding_units (int): Units for node embedding. + heads (int): Number of attention heads + merge_heads (str): How to merge head, using either 'sum' or 'concat'. + pooling_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`MATGlobalPool`. + output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". + output_to_tensor (bool): Whether to cast model output to :obj:`Tensor`. + output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. + Defines number of model outputs and activation. + output_tensor_type (str): Output tensor type. Only "padded" is valid for this implementation. + + Returns: + :obj:`keras.models.Model` + """ + assert input_tensor_type in ["padded", "mask", "masked"], "Only padded tensors are valid for this implementation." + # Make input + node_input = ks.layers.Input(**inputs[0]) + xyz_input = ks.layers.Input(**inputs[1]) + adjacency_matrix = ks.layers.Input(**inputs[2]) + node_mask = ks.layers.Input(**inputs[3]) + adjacency_mask = ks.layers.Input(**inputs[4]) + + use_edge_embedding = input_edge_embedding is not None + use_node_embedding = input_node_embedding is not None + # Embedding, if no feature dimension + if use_node_embedding: + n = Embedding(**input_node_embedding)(node_input) + else: + n = node_input + if use_edge_embedding: + adj = Embedding(**input_edge_embedding)(adjacency_matrix) + else: + adj = adjacency_matrix + n_mask = node_mask + adj_mask = adjacency_mask + xyz = xyz_input + + # Cast to dense Tensor with padding for MAT. + # Nodes must have feature dimension. + dist, dist_mask = MATDistanceMatrix(**distance_matrix_kwargs)(xyz, mask=n_mask) + + # Check shapes + # print(n.shape, dist.shape, adj.shape) + # print(n_mask.shape, dist_mask.shape, adj_mask.shape) + + # Adjacency is derived from edge input. 
If edge input has no last dimension and no embedding is used, then adjacency + # matrix will have shape (batch, max_atoms, max_atoms) and edge input should be ones or weights or bond degree. + # Otherwise, adjacency bears feature expanded from edge attributes of shape (batch, max_atoms, max_atoms, features). + has_edge_dim = len(inputs[2]["shape"]) >= 3 or len(inputs[2]["shape"]) < 3 and use_edge_embedding + + if has_edge_dim: + # Assume that feature-wise attention is not desired for adjacency, reduce to single value. + adj = ks.layers.Dense(1, use_bias=False)(adj) + adj_mask = MATReduceMask(axis=-1, keepdims=True)(adj_mask) + else: + # Make sure that shape is (batch, max_atoms, max_atoms, 1). + adj = MATExpandMask(axis=-1)(adj) + adj_mask = MATExpandMask(axis=-1)(adj_mask) + + # Repeat for depth. + h_mask = n_mask + h = ks.layers.Dense(units=embedding_units, use_bias=False)(n) # Assert correct feature dimension for skip. + for _ in range(depth): + # 1. Norm + Attention + Residual + hn = ks.layers.LayerNormalization()(h) + hs = [ + MATAttentionHead(**attention_kwargs)( + [hn, dist, adj], + mask=[n_mask, dist_mask, adj_mask] + ) + for _ in range(heads) + ] + if merge_heads in ["add", "sum", "reduce_sum"]: + hu = ks.layers.Add()(hs) + hu = ks.layers.Dense(units=embedding_units, use_bias=False)(hu) + else: + hu = ks.layers.Concatenate(axis=-1)(hs) + hu = ks.layers.Dense(units=embedding_units, use_bias=False)(hu) + h = ks.layers.Add()([h, hu]) + + # 2. Norm + MLP + Residual + hn = ks.layers.LayerNormalization()(h) + hu = MLP(**feed_forward_kwargs)(hn) + hu = ks.layers.Dense(units=embedding_units, use_bias=False)(hu) + hu = ks.layers.Multiply()([hu, h_mask]) + h = ks.layers.Add()([h, hu]) + + # pooling output + out = h + out_mask = h_mask + out = ks.layers.LayerNormalization()(out) + if output_embedding == 'graph': + out = ks.layers.Multiply()([out, out_mask]) + out = MATGlobalPool(**pooling_kwargs)(out, mask=out_mask) + # final prediction MLP for the output! + out = MLP(**output_mlp)(out) + elif output_embedding == 'node': + out = MLP(**output_mlp)(out) + out = ks.layers.Multiply()([out, out_mask]) + else: + raise ValueError("Unsupported graph embedding for mode `MAT` .") + + model = ks.models.Model( + inputs=[node_input, xyz_input, adjacency_matrix], + outputs=out, + name=name + ) + model.__kgcnn_model_version__ = __model_version__ + return model diff --git a/kgcnn/literature/MAT/_model.py b/kgcnn/literature/MAT/_model.py new file mode 100644 index 00000000..e69de29b diff --git a/kgcnn/literature/MEGAN/_make.py b/kgcnn/literature/MEGAN/_make.py index e69de29b..15307975 100644 --- a/kgcnn/literature/MEGAN/_make.py +++ b/kgcnn/literature/MEGAN/_make.py @@ -0,0 +1,212 @@ +from ._model import MEGAN +from kgcnn.models.utils import update_model_kwargs +from kgcnn.layers.modules import Input +from keras.backend import backend as backend_to_use +from kgcnn.models.casting import template_cast_list_input, template_cast_output +from kgcnn.ops.activ import * + +# Keep track of model version from commit date in literature. +# To be updated if model is changed in a significant way. +__model_version__ = "2023-12-08" + +# Supported backends +__kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"] +if backend_to_use() not in __kgcnn_model_backend_supported__: + raise NotImplementedError("Backend '%s' for model 'MEGAN' is not supported." 
% backend_to_use()) + + +# Implementation of INorp in `tf.keras` from paper: +# 'Interaction Networks for Learning about Objects, Relations and Physics' +# by Peter W. Battaglia, Razvan Pascanu, Matthew Lai, Danilo Rezende, Koray Kavukcuoglu +# http://papers.nips.cc/paper/6417-interaction-networks-for-learning-about-objects-relations-and-physics +# https://arxiv.org/abs/1612.00222 +# https://github.com/higgsfield/interaction_network_pytorch + + +model_default = { + "inputs": [ + {'shape': (None, 128), 'name': "node_attributes", 'dtype': 'float32'}, + {'shape': (None, 64), 'name': "edge_attributes", 'dtype': 'float32'}, + {'shape': (None, 2), 'name': "edge_indices", 'dtype': 'int64'}, + {'shape': (1,), 'name': "graph_labels", 'dtype': 'float32'}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_edges", "dtype": "int64"} + ], + 'input_embedding': None, + "cast_disjoint_kwargs": {}, + "units": [128], + "activation": "kgcnn>leaky_relu", + "use_bias": True, + "dropout_rate": 0.0, + "use_edge_features": True, + "input_node_embedding": None, + # node/edge importance related arguments + "importance_units": [], + "importance_channels": 2, + "importance_activation": "sigmoid", # do not change + "importance_dropout_rate": 0.0, # do not change + "importance_factor": 0.0, + "importance_multiplier": 10.0, + "sparsity_factor": 0.0, + "concat_heads": True, + # mlp tail end related arguments + "final_units": [1], + "final_dropout_rate": 0.0, + "final_activation": 'linear', + "final_pooling": 'sum', + "regression_limits": None, + "regression_reference": None, + "return_importances": True, + "output_tensor_type": "padded", + 'output_embedding': 'graph', +} + + +@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"]) +def make_model(inputs: list = None, + name: str = None, + input_tensor_type: str = None, + cast_disjoint_kwargs: dict = None, + units: list = None, + activation: str = None, + use_bias: bool = None, + dropout_rate: float = None, + use_edge_features: bool = None, + input_node_embedding: dict = None, + # node/edge importance related arguments + importance_units: list = None, + importance_channels: int = None, + importance_activation: str = None, # do not change + importance_dropout_rate: float = None, # do not change + importance_factor: float = None, + importance_multiplier: float = None, + sparsity_factor: float = None, + concat_heads: bool = None, + # mlp tail end related arguments + final_units: list = None, + final_dropout_rate: float = None, + final_activation: str = None, + final_pooling: str = None, + regression_limits: tuple = None, + regression_reference: float = None, + return_importances: bool = True, + output_embedding: dict = None, + output_tensor_type: str = None + ): + r"""Functional model definition of MEGAN. Please check documentation of :obj:`kgcnn.literature.MEGAN` . + + **Model inputs**: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, edges, edge_indices, graph_labels...]` + with '...' indicating mask or ID tensors following the template below. + Graph labels are used to generate explanations but not to influence model output. + + %s + + **Model outputs**: + The standard output template: + + %s + + Args: + name: Name of the model. + inputs (list): List of dictionaries unpacked in :obj:`keras.layers.Input`. Order must match model definition. + input_tensor_type (str): Input type of graph tensor. Default is "padded". 
+ cast_disjoint_kwargs (dict): Dictionary of arguments for casting layer. + units: A list of ints where each element configures an additional attention layer. The numeric + value determines the number of hidden units to be used in the attention heads of that layer + activation: The activation function to be used within the attention layers of the network + use_bias: Whether the layers of the network should use bias weights at all + dropout_rate: The dropout rate to be applied after *each* of the attention layers of the network. + input_node_embedding: Dictionary of embedding kwargs for input embedding layer. + use_edge_features: Whether edge features should be used. Generally the network supports the + usage of edge features, but if the input data does not contain edge features, this should be + set to False. + importance_units: A list of ints where each element configures another dense layer in the + subnetwork that produces the node importance tensor from the main node embeddings. The + numeric value determines the number of hidden units in that layer. + importance_channels: The int number of explanation channels to be produced by the network. This + is the value referred to as "K". Note that this will also determine the number of attention + heads used within the attention subnetwork. + importance_factor: The weight of the explanation-only train step. If this is set to exactly + zero then the explanation train step will not be executed at all (less computationally + expensive) + importance_multiplier: An additional hyperparameter of the explanation-only train step. This + is essentially the scaling factor that is applied to the values of the dataset such that + the target values can reasonably be approximated by a sum of [0, 1] importance values. + sparsity_factor: The coefficient for the sparsity regularization of the node importance + tensor. + concat_heads: Whether to concat the heads of the attention subnetwork. The default is True. In + that case the output of each individual attention head is concatenated and the concatenated + vector is then used as the input of the next attention layer's heads. If this is False, the + vectors are average pooled instead. + final_units: A list of ints where each element configures another dense layer in the MLP + at the tail end of the network. The numeric value determines the number of the hidden units + in that layer. Note that the final element in this list has to be the same as the dimension + to be expected for the samples of the training dataset! + final_dropout_rate: The dropout rate to be applied after *every* layer of the final MLP. + final_activation: The activation to be applied at the very last layer of the MLP to produce the + actual output of the network. + final_pooling: The pooling method to be used during the global pooling phase in the network. + regression_limits: A tuple where the first value is the lower limit for the expected value range + of the regression task and teh second value the upper limit. + regression_reference: A reference value which is inside the range of expected values (best if + it was in the middle, but does not have to). Choosing different references will result + in different explanations. + return_importances: Whether the importance / explanation tensors should be returned as an output + of the model. 
If this is True, the output of the model will be a 3-tuple: + (output, node importances, edge importances), otherwise it is just the output itself + output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". + output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". + + Returns: + :obj:`keras.models.Model` + """ + model_inputs = [Input(**x) for x in inputs] + + di_inputs = template_cast_list_input( + model_inputs, + input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + mask_assignment=[0, 1, 1, None], + index_assignment=[None, None, 0, None] + ) + + n, ed, disjoint_indices, gs, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = di_inputs + + # Wrapping disjoint model. + out = MEGAN( + units=units, + activation=activation, + use_bias=use_bias, + dropout_rate=dropout_rate, + use_edge_features=use_edge_features, + input_node_embedding=input_node_embedding, + importance_units=importance_units, + importance_channels=importance_channels, + importance_activation=importance_activation, # do not change + importance_dropout_rate=importance_dropout_rate, # do not change + importance_factor=importance_factor, + importance_multiplier=importance_multiplier, + sparsity_factor=sparsity_factor, + concat_heads=concat_heads, + final_units=final_units, + final_dropout_rate=final_dropout_rate, + final_activation=final_activation, + final_pooling=final_pooling, + regression_limits=regression_limits, + regression_reference=regression_reference, + return_importances=return_importances + )([n, ed, disjoint_indices, gs, batch_id_node, count_nodes]) + + # Output embedding choice + out = template_cast_output( + [out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges], + output_embedding=output_embedding, output_tensor_type=output_tensor_type, + input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, + ) + + model = ks.models.Model(inputs=model_inputs, outputs=out, name=name) + model.__kgcnn_model_version__ = __model_version__ + + return model diff --git a/kgcnn/literature/MEGAN/_model.py b/kgcnn/literature/MEGAN/_model.py index f7e85c9f..595b6486 100644 --- a/kgcnn/literature/MEGAN/_model.py +++ b/kgcnn/literature/MEGAN/_model.py @@ -193,6 +193,9 @@ def __init__(self, if self.regression_limits is not None: self.regression_width = np.abs(self.regression_limits[1] - self.regression_limits[0]) + def build(self, input_shape): + super(MEGAN, self).build(input_shape) + def get_config(self): config = super(MEGAN, self).get_config() config.update({ @@ -218,7 +221,6 @@ def get_config(self): "return_importances": self.return_importances, "input_node_embedding": self.input_node_embedding }) - return config @property diff --git a/kgcnn/literature/MXMNet/_layers.py b/kgcnn/literature/MXMNet/_layers.py new file mode 100644 index 00000000..8246d378 --- /dev/null +++ b/kgcnn/literature/MXMNet/_layers.py @@ -0,0 +1,216 @@ +import keras as ks +from keras.layers import Layer, Add, Multiply, Concatenate, Dense +from kgcnn.literature.DimeNetPP._layers import ResidualLayer +from kgcnn.layers.mlp import GraphMLP +from kgcnn.layers.gather import GatherNodes, GatherNodesOutgoing +from kgcnn.layers.aggr import AggregateLocalEdges as PoolingLocalMessages + + +class MXMGlobalMP(Layer): + + def __init__(self, units: int = 64, **kwargs): + """Initialize layer.""" + super(MXMGlobalMP, self).__init__(**kwargs) + self.dim = units + self.h_mlp = GraphMLP(self.dim, 
activation="swish") + self.res1 = ResidualLayer(self.dim) + self.res2 = ResidualLayer(self.dim) + self.res3 = ResidualLayer(self.dim) + self.mlp = GraphMLP(self.dim, activation="swish") + self.add_res = Add() + + self.x_edge_mlp = GraphMLP(self.dim, activation="swish") + self.linear = Dense(self.dim, use_bias=False, activation="linear") + + self.gather = GatherNodes(split_indices=[0, 1], concat_axis=None) + self.pool = PoolingLocalMessages() + self.cat = Concatenate() + self.multiply_edge = Multiply() + self.add = Add() + + def build(self, input_shape): + """Build layer.""" + super(MXMGlobalMP, self).build(input_shape) + + def propagate(self, edge_index, x, edge_attr, **kwargs): + x_i, x_j = self.gather([x, edge_index]) + + # Prepare message. + x_edge = self.cat([x_i, x_j, edge_attr], axis=-1) + x_edge = self.x_edge_mlp(x_edge, **kwargs) + edge_attr_lin = self.linear(edge_attr, **kwargs) + x_edge = self.multiply_edge([edge_attr_lin, x_edge]) + + # Pooling here. + x_p = self.pool([x, x_edge, edge_index]) + + # Replace self loops by explicit node update here. + x_i_p = self.add([x_p, x]) + + return x_i_p + + def call(self, inputs, **kwargs): + r"""Forward pass. + + Args: + inputs: [nodes, edges, tensor_index] + + - nodes (Tensor): Node embeddings of shape `([N], F)` + - edges (Tensor): Edge or message embeddings of shape `([M], F)` + - tensor_index (Tensor): Edge indices referring to nodes of shape `(2, [M])` + + Returns: + Tensor: Node embeddings. + """ + h, edge_attr, edge_index = inputs + + # Keep for residual skip connections. + res_h = h + + # Integrate the Cross Layer Mapping inside the Global Message Passing + h = self.h_mlp(h) + + # Message Passing operation + h = self.propagate(edge_index=edge_index, x=h, edge_attr=edge_attr, **kwargs) + + # Update function f_u + h = self.res1(h) + h = self.mlp(h) + h = self.add_res([h, res_h]) + h = self.res2(h) + h = self.res3(h) + + # Message Passing operation + h = self.propagate(edge_index=edge_index, x=h, edge_attr=edge_attr, **kwargs) + + return h + + def get_config(self): + config = super(MXMGlobalMP, self).get_config() + config.update({"units": self.dim}) + return config + + +class MXMLocalMP(Layer): + + def __init__(self, units: int = 64, output_units: int = 1, activation: str = "swish", + output_kernel_initializer: str = "zeros", pooling_method: str = "sum", **kwargs): + super(MXMLocalMP, self).__init__(**kwargs) + self.dim = units + self.output_dim = output_units + self.activation = activation + self.pooling_method = pooling_method + self.h_mlp = GraphMLP(self.dim, activation=activation) + + self.mlp_kj = GraphMLP([self.dim], activation=activation) + self.mlp_ji_1 = GraphMLP([self.dim], activation=activation) + self.mlp_ji_2 = GraphMLP([self.dim], activation=activation) + self.mlp_jj = GraphMLP([self.dim], activation=activation) + + self.mlp_sbf1 = GraphMLP([self.dim, self.dim], activation=activation) + self.mlp_sbf2 = GraphMLP([self.dim, self.dim], activation=activation) + self.lin_rbf1 = Dense(self.dim, use_bias=False, activation="linear") + self.lin_rbf2 = Dense(self.dim, use_bias=False, activation="linear") + + self.res1 = ResidualLayer(self.dim) + self.res2 = ResidualLayer(self.dim) + self.res3 = ResidualLayer(self.dim) + + self.lin_rbf_out = Dense(self.dim, use_bias=False, activation="linear") + + self.h_mlp = GraphMLP(self.dim, activation=activation) + + self.y_mlp = GraphMLP([self.dim, self.dim, self.dim], activation=activation) + self.y_W = Dense(self.output_dim, activation="linear", + 
kernel_initializer=output_kernel_initializer) + self.add_res = Add() + + self.gather_nodes = GatherNodes(split_indices=[0, 1], concat_axis=None) + self.cat = Concatenate() + self.multiply = Multiply() + self.gather_mkj = GatherNodesOutgoing() + self.gather_mjj = GatherNodesOutgoing() + self.pool_mkj = PoolingLocalMessages(pooling_method=pooling_method) + self.pool_mjj = PoolingLocalMessages(pooling_method=pooling_method) + self.pool_h = PoolingLocalMessages(pooling_method=pooling_method) + self.add_mji_1 = Add() + self.add_mji_2 = Add() + + def call(self, inputs, **kwargs): + r"""Forward pass. + + Args: + inputs: [h, rbf, sbf1, sbf2, edge_index, angle_idx_1, angle_idx_2] + + - h (Tensor): Node embeddings of shape `([N], F)` + - rbf (Tensor): Radial basis functions of shape `([M], F)` + - sbf1 (Tensor): Spherical basis functions of shape `([K], F)` + - sbf2 (Tensor): Spherical basis functions of shape `([K], F)` + - edge_index (Tensor): Edge indices of shape `(2, [M])` + - angle_idx_1 (Tensor): Angle 1 indices of shape `(2, [K])` + - angle_idx_2 (Tensor): Angle 2 indices of shape `(2, [K])` + + Returns: + Tensor: Node embeddings. + """ + h, rbf, sbf1, sbf2, edge_index, angle_idx_1, angle_idx_2 = inputs + res_h = h + + # Integrate the Cross Layer Mapping inside the Local Message Passing + h = self.h_mlp(h, **kwargs) + + # Message Passing 1 + hi, hj = self.gather_nodes([h, edge_index]) + m = self.cat([hi, hj, rbf]) + + m_kj = self.mlp_kj(m, **kwargs) + w_rbf1 = self.lin_rbf1(rbf, **kwargs) + m_kj = self.multiply([m_kj, w_rbf1]) + m_kj = self.gather_mkj([m_kj, angle_idx_1]) + sw_sbf1 = self.mlp_sbf1(sbf1, **kwargs) + m_kj = self.multiply([m_kj, sw_sbf1]) + m_kj = self.pool_mkj([m, m_kj, angle_idx_1]) + + m_ji_1 = self.mlp_ji_1(m, **kwargs) + + m = self.add_mji_1([m_ji_1, m_kj]) + + # Message Passing 2 (index jj denotes j'i in the main paper) + m_jj = self.mlp_jj(m, **kwargs) + w_rbf2 = self.lin_rbf2(rbf, **kwargs) + m_jj = self.multiply([m_jj, w_rbf2]) + m_jj = self.gather_mjj([m_jj, angle_idx_2]) + sw_sbf2 = self.mlp_sbf2(sbf2, **kwargs) + m_jj = self.multiply([m_jj, sw_sbf2]) + m_jj = self.pool_mjj([m, m_jj, angle_idx_2]) + + m_ji_2 = self.mlp_ji_2(m, **kwargs) + + m = self.add_mji_2([m_ji_2, m_jj]) + + # Aggregation + w_rbf = self.lin_rbf_out(rbf, **kwargs) + m = self.multiply([w_rbf, m]) + h = self.pool_h([h, m, edge_index]) + + # Update function f_u + h = self.res1(h, **kwargs) + h = self.h_mlp(h, **kwargs) + h = self.add_res([h, res_h]) + h = self.res2(h, **kwargs) + h = self.res3(h, **kwargs) + + # Output Module + y = self.y_mlp(h, **kwargs) + y = self.y_W(y, **kwargs) + + return h, y + + def get_config(self): + config = super(MXMLocalMP, self).get_config() + out_conf = self.y_W.get_config() + config.update({"units": self.dim, "output_units": self.output_dim, + "activation": ks.activations.serialize(ks.activations.get(self.activation)), + "output_kernel_initializer": out_conf["kernel_initializer"], + "pooling_method": self.pooling_method}) + return config diff --git a/kgcnn/literature/MXMNet/_make.py b/kgcnn/literature/MXMNet/_make.py new file mode 100644 index 00000000..6af215c3 --- /dev/null +++ b/kgcnn/literature/MXMNet/_make.py @@ -0,0 +1,209 @@ +import keras as ks +from kgcnn.layers.scale import get as get_scaler +from ._model import model_disjoint +from kgcnn.layers.modules import Input +from kgcnn.models.casting import template_cast_output, template_cast_list_input +from kgcnn.models.utils import update_model_kwargs +from keras.backend import backend as backend_to_use + +# To be 
updated if model is changed in a significant way. +__model_version__ = "2023-12-09" + +# Supported backends +__kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"] +if backend_to_use() not in __kgcnn_model_backend_supported__: + raise NotImplementedError("Backend '%s' for model 'MXMNet' is not supported." % backend_to_use()) + +# Implementation of MXMNet in `tf.keras` from paper: +# Molecular Mechanics-Driven Graph Neural Network with Multiplex Graph for Molecular Structures +# by Shuo Zhang, Yang Liu, Lei Xie (2020) +# https://arxiv.org/abs/2011.07457 +# https://github.com/zetayue/MXMNet + + +model_default = { + "name": "MXMNet", + "inputs": [ + {"shape": (None, ), "name": "node_number", "dtype": "float32"}, + {"shape": (None, 3), "name": "node_coordinates", "dtype": "float32"}, + {"shape": (None, 64), "name": "edge_attributes", "dtype": "float32"}, + {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, + {"shape": (None, 2), "name": "range_indices", "dtype": "int64"}, + {"shape": [None, 2], "name": "angle_indices_1", "dtype": "int64"}, + {"shape": [None, 2], "name": "angle_indices_2", "dtype": "int64"}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_edges", "dtype": "int64"}, + {"shape": (), "name": "total_ranges", "dtype": "int64"}, + {"shape": (), "name": "total_angles_1", "dtype": "int64"}, + {"shape": (), "name": "total_angles_2", "dtype": "int64"} + ], + "input_tensor_type": "padded", + "input_embedding": None, # deprecated + "cast_disjoint_kwargs": {}, + "input_node_embedding": { + "input_dim": 95, "output_dim": 32, + "embeddings_initializer": { + "class_name": "RandomUniform", + "config": {"minval": -1.7320508075688772, "maxval": 1.7320508075688772}} + }, + "input_edge_embedding": {"input_dim": 32, "output_dim": 32}, + "bessel_basis_local": {"num_radial": 16, "cutoff": 5.0, "envelope_exponent": 5}, + "bessel_basis_global": {"num_radial": 16, "cutoff": 5.0, "envelope_exponent": 5}, # Should match range_indices + "spherical_basis_local": {"num_spherical": 7, "num_radial": 6, "cutoff": 5.0, "envelope_exponent": 5}, + "mlp_rbf_kwargs": {"units": 32, "activation": "swish"}, + "mlp_sbf_kwargs": {"units": 32, "activation": "swish"}, + "global_mp_kwargs": {"units": 32}, + "local_mp_kwargs": {"units": 32, "output_units": 1, "output_kernel_initializer": "zeros"}, + "use_edge_attributes": False, + "depth": 3, + "verbose": 10, + "node_pooling_args": {"pooling_method": "sum"}, + "output_embedding": "graph", + "use_output_mlp": True, + "output_mlp": {"use_bias": [True], "units": [1], + "activation": ["linear"]}, + "output_tensor_type": "padded", + "output_scaling": None, + "output_to_tensor": None # deprecated +} + + +@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"]) +def make_model(inputs: list = None, + input_tensor_type: str = None, + cast_disjoint_kwargs: dict = None, + input_embedding: dict = None, # noqa + input_node_embedding: dict = None, + input_edge_embedding: dict = None, + depth: int = None, + name: str = None, + bessel_basis_local: dict = None, + bessel_basis_global: dict = None, + spherical_basis_local: dict = None, + use_edge_attributes: bool = None, + mlp_rbf_kwargs: dict = None, + mlp_sbf_kwargs: dict = None, + global_mp_kwargs: dict = None, + local_mp_kwargs: dict = None, + verbose: int = None, + output_embedding: str = None, + use_output_mlp: bool = None, + node_pooling_args: dict = None, + output_to_tensor: bool = None, + output_mlp: dict = None, + 
output_scaling: dict = None, + output_tensor_type: str = None, + ): + r"""Make `MXMNet `__ graph network via functional API. + Default parameters can be found in :obj:`kgcnn.literature.MXMNet.model_default` . + + **Model inputs**: + Model uses the list template of inputs and standard output template. + The supported inputs are + :obj:`[nodes, coordinates, edge_attributes, edge_indices, range_indices, angle_indices_1, angle_indices_2, ...]` + with '...' indicating mask or ID tensors following the template below. + Note that you must supply angle indices as index pairs that refer to two edges or two range connections. + + %s + + **Model outputs**: + The standard output template: + + %s + + Args: + inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition. + input_tensor_type (str): Input type of graph tensor. Default is "padded". + cast_disjoint_kwargs (dict): Dictionary of arguments for casting layer. + input_embedding (dict): Deprecated in favour of input_node_embedding etc. + input_node_embedding (dict): Dictionary of embedding arguments for nodes unpacked in :obj:`Embedding` layers. + input_edge_embedding (dict): Dictionary of embedding arguments for nodes unpacked in :obj:`Embedding` layers. + depth (int): Number of graph embedding units or depth of the network. + verbose (int): Level of verbosity. + name (str): Name of the model. + bessel_basis_local: Dictionary of layer arguments unpacked in local `:obj:BesselBasisLayer` layer. + bessel_basis_global: Dictionary of layer arguments unpacked in global `:obj:BesselBasisLayer` layer. + spherical_basis_local: Dictionary of layer arguments unpacked in `:obj:SphericalBasisLayer` layer. + use_edge_attributes: Whether to add edge attributes. Default is False. + mlp_rbf_kwargs: Dictionary of layer arguments unpacked in `:obj:MLP` layer for RBF feed-forward. + mlp_sbf_kwargs: Dictionary of layer arguments unpacked in `:obj:MLP` layer for SBF feed-forward. + global_mp_kwargs: Dictionary of layer arguments unpacked in `:obj:MXMGlobalMP` layer. + local_mp_kwargs: Dictionary of layer arguments unpacked in `:obj:MXMLocalMP` layer. + node_pooling_args (dict): Dictionary of layer arguments unpacked in :obj:`PoolingNodes` layers. + output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". + use_output_mlp (bool): Whether to use the final output MLP. Possibility to skip final MLP. + output_to_tensor (bool): Whether to cast model output to :obj:`tf.Tensor` . + output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. + Defines number of model outputs and activation. + output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None. + output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". 
+ + Returns: + :obj:`keras.models.Model` + """ + # Make input + model_inputs = [Input(**x) for x in inputs] + + dj = template_cast_list_input( + model_inputs, + input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + mask_assignment=[0, 0, 1, 1, 2, 3, 4], + index_assignment=[None, None, None, 0, 0, 3, 3] + ) + + n, x, ed, edi, rgi, adi1, adi2 = dj[:7] + batch_id_node, batch_id_edge, batch_id_ranges, batch_id_angles_1, batch_id_angles_2 = dj[7:12] + node_id, edge_id, range_id, angle_id1, angle_id2 = dj[12:17] + count_nodes, count_edges, count_ranges, count_angles1, count_angles2 = dj[17:] + + out = model_disjoint( + dj, + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, + use_edge_embedding=("int" in inputs[2]['dtype']) if input_edge_embedding is not None else False, + input_node_embedding=None, + input_edge_embedding=None, + bessel_basis_local=None, + spherical_basis_local=None, + bessel_basis_global=None, + use_edge_attributes=None, + mlp_rbf_kwargs=None, + mlp_sbf_kwargs=None, + depth=None, + global_mp_kwargs=None, + local_mp_kwargs=None, + node_pooling_args=None, + output_embedding=None, + use_output_mlp=None, + output_mlp=None + ) + + if output_scaling is not None: + scaler = get_scaler(output_scaling["name"])(**output_scaling) + if scaler.extensive: + # Node information must be numbers, or we need an additional input. + out = scaler([out, n, batch_id_node]) + else: + out = scaler(out) + + # Output embedding choice + out = template_cast_output( + [out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges], + output_embedding=output_embedding, output_tensor_type=output_tensor_type, + input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, + ) + + model = ks.models.Model(inputs=model_inputs, outputs=out, name=name) + + model.__kgcnn_model_version__ = __model_version__ + + if output_scaling is not None: + def set_scale(*args, **kwargs): + scaler.set_scale(*args, **kwargs) + + setattr(model, "set_scale", set_scale) + + return model + + +make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) diff --git a/kgcnn/literature/MXMNet/_model.py b/kgcnn/literature/MXMNet/_model.py new file mode 100644 index 00000000..aa5bfd12 --- /dev/null +++ b/kgcnn/literature/MXMNet/_model.py @@ -0,0 +1,86 @@ +from kgcnn.layers.geom import NodePosition, NodeDistanceEuclidean, BesselBasisLayer, EdgeAngle +from keras.layers import Concatenate, Subtract, Add +from kgcnn.layers.mlp import MLP, GraphMLP +from kgcnn.layers.pooling import PoolingNodes +from kgcnn.literature.DimeNetPP._layers import EmbeddingDimeBlock, SphericalBasisLayer +from ._layers import MXMGlobalMP, MXMLocalMP + + +def model_disjoint( + inputs, + use_node_embedding, + use_edge_embedding, + input_node_embedding=None, + input_edge_embedding=None, + bessel_basis_local=None, + spherical_basis_local=None, + bessel_basis_global=None, + use_edge_attributes=None, + mlp_rbf_kwargs=None, + mlp_sbf_kwargs=None, + depth=None, + global_mp_kwargs=None, + local_mp_kwargs=None, + node_pooling_args=None, + output_embedding=None, + use_output_mlp=None, + output_mlp=None +): + # Make input + n, x, ed, ei_l, ri_g, ai_1, ai_2 = inputs[:7] + batch_id_node, batch_id_edge, batch_id_ranges, batch_id_angles_1, batch_id_angles_2 = inputs[7:12] + node_id, edge_id, range_id, angle_id1, angle_id2 = inputs[12:17] + count_nodes, count_edges, count_ranges, count_angles1, count_angles2 = inputs[17:] + + # 
Rename to short names and make embedding, if no feature dimension. + if use_node_embedding: + n = EmbeddingDimeBlock(**input_node_embedding)(n) + if use_edge_embedding: + n = EmbeddingDimeBlock(**input_edge_embedding)(ed) + + # Calculate distances and spherical and bessel basis for local edges including angles. + # For the first version, we restrict ourselves to 2-hop angles. + pos1_l, pos2_l = NodePosition()([x, ei_l]) + d_l = NodeDistanceEuclidean()([pos1_l, pos2_l]) + rbf_l = BesselBasisLayer(**bessel_basis_local)(d_l) + v12_l = Subtract()([pos1_l, pos2_l]) + a_l_1 = EdgeAngle()([v12_l, ai_1]) + a_l_2 = EdgeAngle(vector_scale=[1.0, -1.0])([v12_l, ai_2]) + sbf_l_1 = SphericalBasisLayer(**spherical_basis_local)([d_l, a_l_1, ai_1]) + sbf_l_2 = SphericalBasisLayer(**spherical_basis_local)([d_l, a_l_2, ai_2]) + + # Calculate distance and bessel basis for global (range) edges. + pos1_g, pos2_g = NodePosition()([x, ri_g]) + d_g = NodeDistanceEuclidean()([pos1_g, pos2_g]) + rbf_g = BesselBasisLayer(**bessel_basis_global)(d_g) + + if use_edge_attributes: + rbf_l = Concatenate()([rbf_l, ed]) + + rbf_l = GraphMLP(**mlp_rbf_kwargs)([rbf_l, batch_id_edge, count_edges]) + sbf_l_1 = GraphMLP(**mlp_sbf_kwargs)([sbf_l_1, batch_id_angles_1, count_angles1]) + sbf_l_2 = GraphMLP(**mlp_sbf_kwargs)([sbf_l_2, batch_id_angles_2, count_angles2]) + rbf_g = GraphMLP(**mlp_rbf_kwargs)([rbf_g, batch_id_ranges, count_ranges]) + + # Model + h = n + nodes_list = [] + for i in range(0, depth): + h = MXMGlobalMP(**global_mp_kwargs)([h, rbf_g, ri_g]) + h, t = MXMLocalMP(**local_mp_kwargs)([h, rbf_l, sbf_l_1, sbf_l_2, ei_l, ai_1, ai_2]) + nodes_list.append(t) + + # Output embedding choice + out = Add()(nodes_list) + if output_embedding == 'graph': + out = PoolingNodes(**node_pooling_args)([count_nodes, out, batch_id_node]) + if use_output_mlp: + out = MLP(**output_mlp)(out) + elif output_embedding == 'node': + out = n + if use_output_mlp: + out = GraphMLP(**output_mlp)([out, batch_id_node, count_nodes]) + else: + raise ValueError("Unsupported output embedding for mode `MXMNet`") + + return out diff --git a/kgcnn/literature/MoGAT/_layers.py b/kgcnn/literature/MoGAT/_layers.py new file mode 100644 index 00000000..20eb92f0 --- /dev/null +++ b/kgcnn/literature/MoGAT/_layers.py @@ -0,0 +1,117 @@ +from keras.layers import Dense, Concatenate, Activation, Layer +from kgcnn.layers.gather import GatherNodesIngoing, GatherNodesOutgoing +from kgcnn.layers.aggr import AggregateLocalEdgesAttention + + +class AttentiveHeadFP_(Layer): + r"""Computes the attention head for `Attentive FP `__ model. + The attention coefficients are computed by :math:`a_{ij} = \sigma_1( W_1 [h_i || h_j] )`. + The initial representation :math:`h_i` and :math:`h_j` must be calculated beforehand. + The attention is obtained by :math:`\alpha_{ij} = \text{softmax}_j (a_{ij})`. + And finally pooled for context :math:`C_i = \sigma_2(\sum_j \alpha_{ij} W_2 h_j)`. + + An edge is defined by index tuple :math:`(i, j)` with the direction of the connection from :math:`j` to :math:`i`. + + Args: + units (int): Units for the linear trafo of node features before attention. + use_edge_features (bool): Append edge features to attention computation. Default is False. + activation (str): Activation. Default is {"class_name": "kgcnn>leaky_relu", "config": {"alpha": 0.2}}. + activation_context (str): Activation function for context. Default is "elu". + use_bias (bool): Use bias. Default is True. + kernel_regularizer: Kernel regularization. Default is None. 
+ bias_regularizer: Bias regularization. Default is None. + activity_regularizer: Activity regularization. Default is None. + kernel_constraint: Kernel constrains. Default is None. + bias_constraint: Bias constrains. Default is None. + kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'. + bias_initializer: Initializer for bias. Default is 'zeros'. + """ + + def __init__(self, + units, + use_edge_features=False, + activation='kgcnn>leaky_relu', + activation_context="elu", + use_bias=True, + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + kernel_initializer='glorot_uniform', + bias_initializer='zeros', + **kwargs): + """Initialize layer.""" + super(AttentiveHeadFP_, self).__init__(**kwargs) + self.use_edge_features = use_edge_features + self.units = int(units) + self.use_bias = use_bias + kernel_args = {"kernel_regularizer": kernel_regularizer, + "activity_regularizer": activity_regularizer, "bias_regularizer": bias_regularizer, + "kernel_constraint": kernel_constraint, "bias_constraint": bias_constraint, + "kernel_initializer": kernel_initializer, "bias_initializer": bias_initializer} + + self.lay_linear_trafo = Dense(units, activation="linear", use_bias=use_bias, **kernel_args) + self.lay_alpha_activation = Dense(units, activation=activation, use_bias=use_bias, **kernel_args) + self.lay_alpha = Dense(1, activation="linear", use_bias=False, **kernel_args) + self.lay_gather_in = GatherNodesIngoing() + self.lay_gather_out = GatherNodesOutgoing() + self.lay_concat = Concatenate(axis=-1) + + self.lay_pool_attention = AggregateLocalEdgesAttention() + self.lay_final_activ = Activation(activation=activation_context) + if use_edge_features: + self.lay_fc1 = Dense(units, activation=activation, use_bias=use_bias, **kernel_args) + self.lay_fc2 = Dense(units, activation=activation, use_bias=use_bias, **kernel_args) + self.lay_concat_edge = Concatenate(axis=-1) + + def build(self, input_shape): + """Build layer.""" + super(AttentiveHeadFP_, self).build(input_shape) + + def call(self, inputs, **kwargs): + """Forward pass. + + Args: + inputs (list): of [node, edges, edge_indices] + + - nodes (Tensor): Node embeddings of shape ([N], F) + - edges (Tensor): Edge or message embeddings of shape ([M], F) + - edge_indices (Tensor): Edge indices referring to nodes of shape (2, [M]) + + Returns: + Tensor: Hidden tensor of pooled edge attentions for each node. + """ + node, edge, edge_index = inputs + + if self.use_edge_features: + n_in = self.lay_gather_in([node, edge_index], **kwargs) + n_out = self.lay_gather_out([node, edge_index], **kwargs) + n_in = self.lay_fc1(n_in, **kwargs) + n_out = self.lay_concat_edge([n_out, edge], **kwargs) + n_out = self.lay_fc2(n_out, **kwargs) + else: + n_in = self.lay_gather_in([node, edge_index], **kwargs) + n_out = self.lay_gather_out([node, edge_index], **kwargs) + + wn_out = self.lay_linear_trafo(n_out, **kwargs) + e_ij = self.lay_concat([n_in, n_out], **kwargs) + e_ij = self.lay_alpha_activation(e_ij, **kwargs) # Maybe uses GAT original definition. + # a_ij = e_ij + a_ij = self.lay_alpha(e_ij, **kwargs) # Should be dimension (batch,None,1) not fully clear in original paper. 
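+        # a_ij are unnormalized per-edge attention logits; AggregateLocalEdgesAttention
+        # softmax-normalizes them over the edges of each receiving node and sums the
+        # attention-weighted, linearly transformed neighbour features wn_out per node.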
+ n_i = self.lay_pool_attention([node, wn_out, a_ij, edge_index], **kwargs) + out = self.lay_final_activ(n_i, **kwargs) + return out + + def get_config(self): + """Update layer config.""" + config = super(AttentiveHeadFP_, self).get_config() + config.update({"use_edge_features": self.use_edge_features, "use_bias": self.use_bias, + "units": self.units}) + conf_sub = self.lay_alpha_activation.get_config() + for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint", + "bias_constraint", "kernel_initializer", "bias_initializer", "activation"]: + config.update({x: conf_sub[x]}) + conf_context = self.lay_final_activ.get_config() + config.update({"activation_context": conf_context["activation"]}) + return config \ No newline at end of file diff --git a/kgcnn/literature/MoGAT/_make.py b/kgcnn/literature/MoGAT/_make.py new file mode 100644 index 00000000..87a260d3 --- /dev/null +++ b/kgcnn/literature/MoGAT/_make.py @@ -0,0 +1,160 @@ +from ._model import model_disjoint +from kgcnn.models.utils import update_model_kwargs +from kgcnn.layers.scale import get as get_scaler +from kgcnn.layers.modules import Input +from keras.backend import backend as backend_to_use +from kgcnn.models.casting import template_cast_list_input, template_cast_output +from kgcnn.ops.activ import * + +# Keep track of model version from commit date in literature. +# To be updated if model is changed in a significant way. +__model_version__ = "2023-12-10" + +# Supported backends +__kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"] +if backend_to_use() not in __kgcnn_model_backend_supported__: + raise NotImplementedError("Backend '%s' for model 'MoGAT' is not supported." % backend_to_use()) + + +# Implementation of MoGAT in `keras` from paper: +# Multi‑order graph attention network for water solubility prediction and interpretation +# Sangho Lee, Hyunwoo Park, Chihyeon Choi, Wonjoon Kim, Ki Kang Kim, Young‑Kyu Han, +# Joohoon Kang, Chang‑Jong Kang & Youngdoo Son +# published March 2nd 2023 +# https://www.nature.com/articles/s41598-022-25701-5 +# https://doi.org/10.1038/s41598-022-25701-5 + + +model_default = { + "name": "MoGAT", + "inputs": [ + {"shape": (None,), "name": "node_number", "dtype": "int64"}, + {"shape": (None,), "name": "edge_number", "dtype": "int64"}, + {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_edges", "dtype": "int64"} + ], + "input_tensor_type": "padded", + "cast_disjoint_kwargs": {}, + "input_embedding": None, # deprecated + "input_node_embedding": {"input_dim": 95, "output_dim": 64}, + "input_edge_embedding": {"input_dim": 5, "output_dim": 64}, + "attention_args": {"units": 32}, + "pooling_gat_nodes_args": {'pooling_method': 'mean'}, + "depthmol": 2, + "depthato": 2, + "dropout": 0.2, + "verbose": 10, + "output_embedding": "graph", + "output_to_tensor": None, # deprecated + "output_tensor_type": "padded", + "output_mlp": {"use_bias": [True], "units": [1], + "activation": ["linear"]}, + "output_scaling": None, +} + + +@update_model_kwargs(model_default) +def make_model(inputs: list = None, + input_tensor_type: str = None, + cast_disjoint_kwargs: dict = None, + input_embedding: dict = None, # noqa + input_node_embedding: dict = None, + input_edge_embedding: dict = None, + depthmol: int = None, + depthato: int = None, + dropout: float = None, + attention_args: dict = None, + pooling_gat_nodes_args: dict = None, + name: str = None, + verbose: int = None, 
+               output_embedding: str = None,
+               output_to_tensor: bool = None,  # noqa
+               output_mlp: dict = None,
+               output_scaling: dict = None,
+               output_tensor_type: str = None,
+               ):
+    r"""Make `MoGAT <https://doi.org/10.1038/s41598-022-25701-5>`__ graph network via functional API.
+    Default parameters can be found in :obj:`kgcnn.literature.MoGAT.model_default` .
+
+    **Model inputs**:
+    Model uses the list template of inputs and standard output template.
+    The supported inputs are :obj:`[nodes, edges, edge_indices, ...]`
+    with '...' indicating mask or ID tensors following the template below:
+
+    %s
+
+    **Model outputs**:
+    The standard output template:
+
+    %s
+
+    Args:
+        inputs (list): List of dictionaries unpacked in :obj:`keras.layers.Input`. Order must match model definition.
+        cast_disjoint_kwargs (dict): Dictionary of arguments for casting layers if used.
+        input_tensor_type (str): Input type of graph tensor. Default is "padded".
+        input_embedding (dict): Deprecated in favour of input_node_embedding etc.
+        input_node_embedding (dict): Dictionary of arguments for nodes unpacked in :obj:`Embedding` layers.
+        input_edge_embedding (dict): Dictionary of arguments for edges unpacked in :obj:`Embedding` layers.
+        depthato (int): Number of atom-level attention layers, i.e. the depth of the node embedding updates.
+        depthmol (int): Depth of the molecule-level attentive readout that forms the graph embedding.
+        dropout (float): Dropout to use.
+        attention_args (dict): Dictionary of layer arguments unpacked in :obj:`AttentiveHeadFP` layer. Units parameter
+            is also used in GRU-update and :obj:`PoolingNodesAttentive`.
+        pooling_gat_nodes_args (dict): Dictionary of arguments for node pooling.
+        name (str): Name of the model.
+        verbose (int): Level of print output.
+        output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph".
+        output_to_tensor (bool): Deprecated in favour of `output_tensor_type` .
+        output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded".
+        output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block.
+            Defines number of model outputs and activation.
+        output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers.
+
+    Returns:
+        :obj:`keras.models.Model`
+    """
+    # Make input
+    model_inputs = [Input(**x) for x in inputs]
+
+    di_inputs = template_cast_list_input(
+        model_inputs,
+        input_tensor_type=input_tensor_type,
+        cast_disjoint_kwargs=cast_disjoint_kwargs,
+        mask_assignment=[0, 1, 1],
+        index_assignment=[None, None, 0]
+    )
+
+    n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = di_inputs
+
+    # Wrapping disjoint model.
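+    # In the disjoint representation, `n` and `ed` are node and edge features of all graphs stacked along
+    # one axis, `disjoint_indices` is the (2, M) edge index tensor, `batch_id_node`/`batch_id_edge` assign
+    # each node/edge to its graph, `node_id`/`edge_id` are positions within each graph, and
+    # `count_nodes`/`count_edges` hold the number of nodes and edges per graph.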
+    out = model_disjoint(
+        [n, ed, disjoint_indices, batch_id_node, count_nodes],
+        use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False,
+        use_edge_embedding=("int" in inputs[1]['dtype']) if input_edge_embedding is not None else False,
+        input_node_embedding=input_node_embedding,
+        input_edge_embedding=input_edge_embedding,
+        attention_args=attention_args,
+        dropout=dropout,
+        depthato=depthato,
+        depthmol=depthmol,
+        output_embedding=output_embedding,
+        output_mlp=output_mlp
+    )
+
+    if output_scaling is not None:
+        scaler = get_scaler(output_scaling["name"])(**output_scaling)
+        out = scaler(out)
+
+    # Output embedding choice
+    out = template_cast_output(
+        [out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges],
+        output_embedding=output_embedding, output_tensor_type=output_tensor_type,
+        input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs,
+    )
+
+    model = ks.models.Model(inputs=model_inputs, outputs=out, name=name)
+    model.__kgcnn_model_version__ = __model_version__
+
+    if output_scaling is not None:
+        def set_scale(*args, **kwargs):
+            scaler.set_scale(*args, **kwargs)
+
+        setattr(model, "set_scale", set_scale)
+
+    return model
+
+
+make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__)
diff --git a/kgcnn/literature/MoGAT/_model.py b/kgcnn/literature/MoGAT/_model.py
new file mode 100644
index 00000000..3f2e3644
--- /dev/null
+++ b/kgcnn/literature/MoGAT/_model.py
@@ -0,0 +1,63 @@
+from kgcnn.layers.update import GRUUpdate
+from kgcnn.layers.modules import Embedding
+from kgcnn.layers.pooling import PoolingNodesAttentive
+from keras.layers import Dense, Dropout, Concatenate, Attention
+from kgcnn.layers.mlp import MLP, GraphMLP
+from ._layers import AttentiveHeadFP_
+
+
+def model_disjoint(
+        inputs,
+        use_node_embedding: bool = None,
+        use_edge_embedding: bool = None,
+        input_node_embedding: dict = None,
+        input_edge_embedding: dict = None,
+        attention_args=None,
+        dropout=None,
+        depthato=None,
+        depthmol=None,
+        output_embedding=None,
+        output_mlp=None
+):
+    # Model implementation with disjoint representation.
+    n, ed, edi, batch_id_node, count_nodes = inputs
+
+    # Embedding, if no feature dimension
+    if use_node_embedding:
+        n = Embedding(**input_node_embedding)(n)
+    if use_edge_embedding:
+        ed = Embedding(**input_edge_embedding)(ed)
+
+    # Model
+    nk = Dense(units=attention_args['units'])(n)
+    ck = AttentiveHeadFP_(use_edge_features=True, **attention_args)([nk, ed, edi])
+    nk = GRUUpdate(units=attention_args['units'])([nk, ck])
+    nk = Dropout(rate=dropout)(nk)  # Dropout after the first update; not present in the original AttentiveFP code.
+    list_emb = [nk]  # aka "r1"
+    for i in range(1, depthato):
+        ck = AttentiveHeadFP_(**attention_args)([nk, ed, edi])
+        nk = GRUUpdate(units=attention_args['units'])([nk, ck])
+        nk = Dropout(rate=dropout)(nk)
+        list_emb.append(nk)
+
+    # Keep the node representation of every depth (r1, r2, ...).
+
+    if output_embedding == 'graph':
+        # Apply an attentive super-node readout to each node representation and concatenate the results.
+        out = Concatenate()([
+            # Tensor output.
+            PoolingNodesAttentive(units=attention_args['units'], depth=depthmol)([count_nodes, ni, batch_id_node]) for
+            ni in list_emb])
+        # Compute the scaled self-attention over the super-node (graph-level) states.
+        at = Attention(dropout=dropout, use_scale=True, score_mode="dot")([out, out])
+        # Weight the concatenated graph states elementwise by the attention output.
+        out = at * out
+        # In the paper this is a single dense layer to the target.
+        out = MLP(**output_mlp)(out)
+
+    elif output_embedding == 'node':
+        out = GraphMLP(**output_mlp)([n, batch_id_node, count_nodes])
+    else:
+        raise ValueError("Unsupported output embedding for model `MoGAT` .")
+
+    return out
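For reviewers who want to try the new model, the following sketch shows how the `make_model` factory added in this patch could be called with its defaults. It is a minimal, illustrative example and not part of the patch: it assumes the usual `kgcnn.literature.MoGAT` re-export of `make_model` and `model_default` from `._make`, and the numpy arrays are hypothetical padded graphs that merely follow the input template declared in `model_default["inputs"]`.

# Minimal usage sketch (illustrative only, not part of this patch).
import numpy as np
from kgcnn.literature.MoGAT import make_model, model_default  # assumes the usual package re-export

# Build MoGAT with the defaults defined in _make.py above.
model = make_model(**model_default)
model.summary()

# Two hypothetical padded example graphs: up to 4 atoms and 6 directed bonds each.
node_number = np.zeros((2, 4), dtype="int64")
edge_number = np.zeros((2, 6), dtype="int64")
edge_indices = np.zeros((2, 6, 2), dtype="int64")
total_nodes = np.array([3, 4], dtype="int64")   # actual number of nodes per graph
total_edges = np.array([4, 6], dtype="int64")   # actual number of edges per graph

# Inputs follow the order of model_default["inputs"]; output is one graph-level prediction per molecule.
out = model.predict([node_number, edge_number, edge_indices, total_nodes, total_edges])
print(out.shape)  # e.g. (2, 1)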