From f8de6d5d4165d88dd068aef442c35ec9df3d14e3 Mon Sep 17 00:00:00 2001
From: PatReis
Date: Wed, 15 Nov 2023 13:13:05 +0100
Subject: [PATCH] update for keras 3.0

---
 kgcnn/layers/attention.py              | 115 +++++++++++++++
 kgcnn/layers/pooling.py                | 192 ++++++++++++++++++++++++-
 kgcnn/layers/update.py                 | 180 +++++++++++++++++++++++
 kgcnn/literature/AttentiveFP/_make.py  | 143 ++++++++++++++++++
 kgcnn/literature/AttentiveFP/_model.py |  50 +++++++
 kgcnn/literature/DMPNN/_make.py        |   4 -
 kgcnn/literature/GAT/_make.py          |   4 +-
 kgcnn/literature/PAiNN/_layers.py      |   2 +-
 kgcnn/literature/PAiNN/_make.py        |   5 +-
 9 files changed, 684 insertions(+), 11 deletions(-)
 create mode 100644 kgcnn/layers/update.py
 create mode 100644 kgcnn/literature/AttentiveFP/_make.py
 create mode 100644 kgcnn/literature/AttentiveFP/_model.py

diff --git a/kgcnn/layers/attention.py b/kgcnn/layers/attention.py
index 21c5675c..d05c7f09 100644
--- a/kgcnn/layers/attention.py
+++ b/kgcnn/layers/attention.py
@@ -330,3 +330,118 @@ def get_config(self):
             'concat_heads': self.concat_heads
         })
         return config
+
+
+class AttentiveHeadFP(Layer):
+    r"""Computes the attention head for `Attentive FP <https://doi.org/10.1021/acs.jmedchem.9b00959>`__ model.
+    The attention coefficients are computed by :math:`a_{ij} = \sigma_1( W_1 [h_i || h_j] )`.
+    The initial representations :math:`h_i` and :math:`h_j` must be calculated beforehand.
+    The attention is obtained by :math:`\alpha_{ij} = \text{softmax}_j (a_{ij})`.
+    And finally pooled for context :math:`C_i = \sigma_2(\sum_j \alpha_{ij} W_2 h_j)`.
+
+    An edge is defined by index tuple :math:`(i, j)` with the direction of the connection from :math:`j` to :math:`i`.
+    """
+
+    def __init__(self,
+                 units,
+                 use_edge_features=False,
+                 activation='kgcnn>leaky_relu',
+                 activation_context="elu",
+                 use_bias=True,
+                 kernel_regularizer=None,
+                 bias_regularizer=None,
+                 activity_regularizer=None,
+                 kernel_constraint=None,
+                 bias_constraint=None,
+                 kernel_initializer='glorot_uniform',
+                 bias_initializer='zeros',
+                 **kwargs):
+        """Initialize layer.
+
+        Args:
+            units (int): Units for the linear trafo of node features before attention.
+            use_edge_features (bool): Append edge features to attention computation. Default is False.
+            activation (str): Activation. Default is {"class_name": "kgcnn>leaky_relu", "config": {"alpha": 0.2}}.
+            activation_context (str): Activation function for context. Default is "elu".
+            use_bias (bool): Use bias. Default is True.
+            kernel_regularizer: Kernel regularization. Default is None.
+            bias_regularizer: Bias regularization. Default is None.
+            activity_regularizer: Activity regularization. Default is None.
+            kernel_constraint: Kernel constraints. Default is None.
+            bias_constraint: Bias constraints. Default is None.
+            kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'.
+            bias_initializer: Initializer for bias. Default is 'zeros'.
+        """
+        super(AttentiveHeadFP, self).__init__(**kwargs)
+        self.use_edge_features = use_edge_features
+        self.units = int(units)
+        self.use_bias = use_bias
+        kernel_args = {"kernel_regularizer": kernel_regularizer,
+                       "activity_regularizer": activity_regularizer, "bias_regularizer": bias_regularizer,
+                       "kernel_constraint": kernel_constraint, "bias_constraint": bias_constraint,
+                       "kernel_initializer": kernel_initializer, "bias_initializer": bias_initializer}
+
+        self.lay_linear_trafo = Dense(units, activation="linear", use_bias=use_bias, **kernel_args)
+        self.lay_alpha_activation = Dense(units, activation=activation, use_bias=use_bias, **kernel_args)
+        self.lay_alpha = Dense(1, activation="linear", use_bias=False, **kernel_args)
+        self.lay_gather_in = GatherNodesIngoing()
+        self.lay_gather_out = GatherNodesOutgoing()
+        self.lay_concat = Concatenate(axis=-1)
+        self.lay_pool_attention = AggregateLocalEdgesAttention()
+        self.lay_final_activ = Activation(activation=activation_context)
+        if use_edge_features:
+            self.lay_fc1 = Dense(units, activation=activation, use_bias=use_bias, **kernel_args)
+            self.lay_fc2 = Dense(units, activation=activation, use_bias=use_bias, **kernel_args)
+            self.lay_concat_edge = Concatenate(axis=-1)
+
+    def build(self, input_shape):
+        """Build layer."""
+        super(AttentiveHeadFP, self).build(input_shape)
+
+    def call(self, inputs, **kwargs):
+        r"""Forward pass.
+
+        Args:
+            inputs (list): [node, edges, edge_indices]
+
+                - nodes (Tensor): Node embeddings of shape ([N], F)
+                - edges (Tensor): Edge or message embeddings of shape ([M], F)
+                - edge_indices (Tensor): Edge indices referring to nodes of shape ([M], 2)
+
+        Returns:
+            Tensor: Hidden tensor of pooled edge attentions for each node.
+        """
+        node, edge, edge_index = inputs
+
+        if self.use_edge_features:
+            n_in = self.lay_gather_in([node, edge_index], **kwargs)
+            n_out = self.lay_gather_out([node, edge_index], **kwargs)
+            n_in = self.lay_fc1(n_in, **kwargs)
+            n_out = self.lay_concat_edge([n_out, edge], **kwargs)
+            n_out = self.lay_fc2(n_out, **kwargs)
+        else:
+            n_in = self.lay_gather_in([node, edge_index], **kwargs)
+            n_out = self.lay_gather_out([node, edge_index], **kwargs)
+
+        wn_out = self.lay_linear_trafo(n_out, **kwargs)
+        e_ij = self.lay_concat([n_in, n_out], **kwargs)
+        e_ij = self.lay_alpha_activation(e_ij, **kwargs)  # Maybe uses GAT original definition.
+        # a_ij = e_ij
+        a_ij = self.lay_alpha(e_ij, **kwargs)  # Should be dimension (None, 1) not fully clear in original paper.
+        n_i = self.lay_pool_attention([node, wn_out, a_ij, edge_index], **kwargs)
+        out = self.lay_final_activ(n_i, **kwargs)
+        return out
+
+    def get_config(self):
+        """Update layer config."""
+        config = super(AttentiveHeadFP, self).get_config()
+        config.update({"use_edge_features": self.use_edge_features, "use_bias": self.use_bias,
+                       "units": self.units})
+        conf_sub = self.lay_alpha_activation.get_config()
+        for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint",
+                  "bias_constraint", "kernel_initializer", "bias_initializer", "activation"]:
+            if x in conf_sub.keys():
+                config.update({x: conf_sub[x]})
+        conf_context = self.lay_final_activ.get_config()
+        config.update({"activation_context": conf_context["activation"]})
+        return config
diff --git a/kgcnn/layers/pooling.py b/kgcnn/layers/pooling.py
index 8c1921ec..cef8c5f4 100644
--- a/kgcnn/layers/pooling.py
+++ b/kgcnn/layers/pooling.py
@@ -1,6 +1,8 @@
 import keras_core as ks
-from keras_core.layers import Layer
+from keras_core.layers import Layer, Dense, Concatenate, GRUCell, Activation
+from kgcnn.layers.gather import GatherState
 from keras_core import ops
+from kgcnn.ops.scatter import scatter_reduce_softmax
 from kgcnn.layers.aggr import Aggregate
 
 
@@ -21,3 +23,191 @@ def compute_output_shape(self, input_shape):
     def call(self, inputs, **kwargs):
         reference, x, idx = inputs
         return self._to_aggregate([x, idx, reference])
+
+
+class PoolingEmbeddingAttention(Layer):
+    r"""Pooling all embeddings of edges or nodes per batch to obtain a graph level embedding in form of a
+    :obj:`Tensor` .
+
+    Uses attention for pooling. i.e. :math:`s = \sum_i \alpha_{i} n_i` .
+    The attention is computed via: :math:`\alpha_i = \text{softmax}_i(a_i)` from the attention
+    coefficients :math:`a_i` .
+    The attention coefficients must be computed beforehand by node or edge features or by :math:`\sigma( W [s || n_i])`
+    and are passed to this layer as input. Thereby this layer has no weights and only does pooling.
+    In summary, :math:`s = \sum_i \text{softmax}_i(a_i) n_i` is computed by the layer.
+    """
+
+    def __init__(self,
+                 softmax_method="scatter_softmax",
+                 pooling_method="scatter_sum",
+                 normalize_softmax: bool = False,
+                 **kwargs):
+        """Initialize layer.
+
+        Args:
+            softmax_method (str): Method to compute the softmax over the batch assignment.
+                Default is "scatter_softmax".
+            pooling_method (str): Pooling method for aggregating the weighted embeddings. Default is "scatter_sum".
+            normalize_softmax (bool): Whether to normalize the softmax. Default is False.
+        """
+        super(PoolingEmbeddingAttention, self).__init__(**kwargs)
+        self.normalize_softmax = normalize_softmax
+        self.pooling_method = pooling_method
+        self.softmax_method = softmax_method
+        self.to_aggregate = Aggregate(pooling_method=pooling_method)
+
+    def build(self, input_shape):
+        """Build layer."""
+        assert len(input_shape) == 4
+        ref_shape, attr_shape, attention_shape, index_shape = [list(x) for x in input_shape]
+        self.to_aggregate.build([tuple(x) for x in [attr_shape, index_shape, ref_shape]])
+        self.built = True
+
+    def call(self, inputs, **kwargs):
+        r"""Forward pass.
+
+        Args:
+            inputs: [reference, attr, attention, batch_index]
+
+                - reference (Tensor): Reference for aggregation of shape `(batch, ...)` .
+                - attr (Tensor): Node or edge embeddings of shape `([N], F)` .
+                - attention (Tensor): Attention coefficients of shape `([N], 1)` .
+                - batch_index (Tensor): Batch assignment of shape `([N], )` .
+
+        Returns:
+            Tensor: Embedding tensor of pooled node of shape `(batch, F)` .
+        """
+        reference, attr, attention, batch_index = inputs
+        shape_attention = ops.shape(reference)[:1] + ops.shape(attention)[1:]
+        a = scatter_reduce_softmax(batch_index, attention, shape=shape_attention, normalize=self.normalize_softmax)
+        x = attr * ops.broadcast_to(a, ops.shape(attr))
+        return self.to_aggregate([x, batch_index, reference])
+
+    def get_config(self):
+        """Update layer config."""
+        config = super(PoolingEmbeddingAttention, self).get_config()
+        config.update({
+            "normalize_softmax": self.normalize_softmax, "pooling_method": self.pooling_method,
+            "softmax_method": self.softmax_method
+        })
+        return config
+
+
+PoolingNodesAttention = PoolingEmbeddingAttention
+
+
+class PoolingNodesAttentive(Layer):
+    r"""Computes the attentive pooling for node embeddings of the
+    `Attentive FP <https://doi.org/10.1021/acs.jmedchem.9b00959>`__ model.
+    """
+
+    def __init__(self,
+                 units,
+                 depth=3,
+                 pooling_method="sum",
+                 activation='kgcnn>leaky_relu',
+                 activation_context="elu",
+                 use_bias=True,
+                 kernel_regularizer=None,
+                 bias_regularizer=None,
+                 activity_regularizer=None,
+                 kernel_constraint=None,
+                 bias_constraint=None,
+                 kernel_initializer='glorot_uniform',
+                 bias_initializer='zeros',
+                 recurrent_activation='sigmoid',
+                 recurrent_initializer='orthogonal',
+                 recurrent_regularizer=None,
+                 recurrent_constraint=None,
+                 dropout=0.0,
+                 recurrent_dropout=0.0,
+                 reset_after=True,
+                 **kwargs):
+        """Initialize layer.
+
+        Args:
+            units (int): Units for the linear trafo of node features before attention.
+            pooling_method (str): Initial pooling before iteration. Default is "sum".
+            depth (int): Number of iterations for graph embedding. Default is 3.
+            activation (str): Activation. Default is {"class_name": "kgcnn>leaky_relu", "config": {"alpha": 0.2}}.
+            activation_context (str): Activation function for context. Default is "elu".
+            use_bias (bool): Use bias. Default is True.
+            kernel_regularizer: Kernel regularization. Default is None.
+            bias_regularizer: Bias regularization. Default is None.
+            activity_regularizer: Activity regularization. Default is None.
+            kernel_constraint: Kernel constraints. Default is None.
+            bias_constraint: Bias constraints. Default is None.
+            kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'.
+            bias_initializer: Initializer for bias. Default is 'zeros'.
+        """
+        super(PoolingNodesAttentive, self).__init__(**kwargs)
+        self.pooling_method = pooling_method
+        self.depth = depth
+        self.units = int(units)
+        kernel_args = {"use_bias": use_bias, "kernel_regularizer": kernel_regularizer,
+                       "activity_regularizer": activity_regularizer, "bias_regularizer": bias_regularizer,
+                       "kernel_constraint": kernel_constraint, "bias_constraint": bias_constraint,
+                       "kernel_initializer": kernel_initializer, "bias_initializer": bias_initializer}
+        gru_args = {"recurrent_activation": recurrent_activation,
+                    "use_bias": use_bias, "kernel_initializer": kernel_initializer,
+                    "recurrent_initializer": recurrent_initializer, "bias_initializer": bias_initializer,
+                    "kernel_regularizer": kernel_regularizer, "recurrent_regularizer": recurrent_regularizer,
+                    "bias_regularizer": bias_regularizer, "kernel_constraint": kernel_constraint,
+                    "recurrent_constraint": recurrent_constraint, "bias_constraint": bias_constraint,
+                    "dropout": dropout, "recurrent_dropout": recurrent_dropout, "reset_after": reset_after}
+
+        self.lay_linear_trafo = Dense(units, activation="linear", **kernel_args)
+        self.lay_alpha = Dense(1, activation=activation, **kernel_args)
+        self.lay_gather_s = GatherState()
+        self.lay_concat = Concatenate(axis=-1)
+        self.lay_pool_start = PoolingNodes(pooling_method=self.pooling_method)
+        self.lay_pool_attention = PoolingNodesAttention()
+        self.lay_final_activ = Activation(activation=activation_context)
+        self.lay_gru = GRUCell(units=units, activation="tanh", **gru_args)
+
+    def build(self, input_shape):
+        """Build layer."""
+        super(PoolingNodesAttentive, self).build(input_shape)
+
+    def call(self, inputs, **kwargs):
+        """Forward pass.
+
+        Args:
+            inputs: [reference, nodes, batch_index]
+
+                - reference (Tensor): Reference for aggregation of shape `(batch, ...)` .
+                - nodes (Tensor): Node embeddings of shape `([N], F)` .
+                - batch_index (Tensor): Batch assignment of shape `([N], )` .
+
+        Returns:
+            Tensor: Hidden tensor of pooled node attentions of shape (batch, F).
+        """
+        ref, node, batch_index = inputs
+
+        h = self.lay_pool_start([ref, node, batch_index], **kwargs)
+        wn = self.lay_linear_trafo(node, **kwargs)
+        for _ in range(self.depth):
+            hv = self.lay_gather_s([h, batch_index], **kwargs)
+            ev = self.lay_concat([hv, node], **kwargs)
+            av = self.lay_alpha(ev, **kwargs)
+ """ + ref, node, batch_index = inputs + + h = self.lay_pool_start([ref, node, batch_index], **kwargs) + wn = self.lay_linear_trafo(node, **kwargs) + for _ in range(self.depth): + hv = self.lay_gather_s([h, batch_index], **kwargs) + ev = self.lay_concat([hv, node], **kwargs) + av = self.lay_alpha(ev, **kwargs) + cont = self.lay_pool_attention([ref, wn, av, batch_index], **kwargs) + cont = self.lay_final_activ(cont, **kwargs) + h, _ = self.lay_gru(cont, h, **kwargs) + + out = h + return out + + def get_config(self): + """Update layer config.""" + config = super(PoolingNodesAttentive, self).get_config() + config.update({"units": self.units, "depth": self.depth, "pooling_method": self.pooling_method}) + conf_sub = self.lay_alpha.get_config() + for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint", + "bias_constraint", "kernel_initializer", "bias_initializer", "activation", "use_bias"]: + if x in conf_sub.keys(): + config.update({x: conf_sub[x]}) + conf_context = self.lay_final_activ.get_config() + config.update({"activation_context": conf_context["activation"]}) + conf_gru = self.lay_gru.get_config() + for x in ["recurrent_activation", "recurrent_initializer", "recurrent_regularizer", "recurrent_constraint", + "dropout", "recurrent_dropout", "reset_after"]: + if x in conf_gru.keys(): + config.update({x: conf_gru[x]}) + return config diff --git a/kgcnn/layers/update.py b/kgcnn/layers/update.py new file mode 100644 index 00000000..61138c78 --- /dev/null +++ b/kgcnn/layers/update.py @@ -0,0 +1,180 @@ +import keras_core as ks +from keras_core.layers import Dense, Add, Layer + + +class GRUUpdate(Layer): + r"""Gated recurrent unit for updating node or edge embeddings. + As proposed by `NMPNN `__ . + """ + + def __init__(self, units, + activation='tanh', recurrent_activation='sigmoid', + use_bias=True, kernel_initializer='glorot_uniform', + recurrent_initializer='orthogonal', + bias_initializer='zeros', kernel_regularizer=None, + recurrent_regularizer=None, bias_regularizer=None, kernel_constraint=None, + recurrent_constraint=None, bias_constraint=None, dropout=0.0, + recurrent_dropout=0.0, reset_after=True, + **kwargs): + r"""Initialize layer. + + Args: + units (int): Units for GRU. + activation: Activation function to use. Default: hyperbolic tangent + (`tanh`). If you pass None, no activation is applied + (ie. "linear" activation: `a(x) = x`). + recurrent_activation: Activation function to use for the recurrent step. + Default: sigmoid (`sigmoid`). If you pass `None`, no activation is + applied (ie. "linear" activation: `a(x) = x`). + use_bias: Boolean, (default `True`), whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. Default: + `glorot_uniform`. + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, used for the linear transformation of the recurrent state. + Default: `orthogonal`. + bias_initializer: Initializer for the bias vector. Default: `zeros`. + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix. Default: `None`. + recurrent_regularizer: Regularizer function applied to the + `recurrent_kernel` weights matrix. Default: `None`. + bias_regularizer: Regularizer function applied to the bias vector. Default: + `None`. + kernel_constraint: Constraint function applied to the `kernel` weights + matrix. Default: `None`. 
+ recurrent_constraint: Constraint function applied to the `recurrent_kernel` + weights matrix. Default: `None`. + bias_constraint: Constraint function applied to the bias vector. Default: + `None`. + dropout: Float between 0 and 1. Fraction of the units to drop for the + linear transformation of the inputs. Default: 0. + recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for + the linear transformation of the recurrent state. Default: 0. + reset_after: GRU convention (whether to apply reset gate after or + before matrix multiplication). False = "before", + True = "after" (default and CuDNN compatible). + """ + super(GRUUpdate, self).__init__(**kwargs) + self.units = units + + self.gru_cell = ks.layers.GRUCell( + units=units, + activation=activation, recurrent_activation=recurrent_activation, + use_bias=use_bias, kernel_initializer=kernel_initializer, + recurrent_initializer=recurrent_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + recurrent_regularizer=recurrent_regularizer, + bias_regularizer=bias_regularizer, + kernel_constraint=kernel_constraint, + recurrent_constraint=recurrent_constraint, + bias_constraint=bias_constraint, + dropout=dropout, + recurrent_dropout=recurrent_dropout, reset_after=reset_after + ) + + def build(self, input_shape): + """Build layer.""" + super(GRUUpdate, self).build(input_shape) + + def call(self, inputs, mask=None, **kwargs): + """Forward pass. + + Args: + inputs (list): [nodes, updates] + + - nodes (Tensor): Node embeddings of shape ([N], F) + - updates (Tensor): Matching node updates of shape ([N], F) + mask: Mask for inputs. Default is None. + + Returns: + Tensor: Updated nodes of shape ([N], F) + """ + n, eu = inputs + out, _ = self.gru_cell(eu, n, **kwargs) + return out + + def get_config(self): + """Update layer config.""" + config = super(GRUUpdate, self).get_config() + conf_cell = self.gru_cell.get_config() + param_list = ["units", "activation", "recurrent_activation", + "use_bias", "kernel_initializer", + "recurrent_initializer", + "bias_initializer", "kernel_regularizer", + "recurrent_regularizer", "bias_regularizer", "kernel_constraint", + "recurrent_constraint", "bias_constraint", "dropout", + "recurrent_dropout", "reset_after"] + for x in param_list: + if x in conf_cell.keys(): + config.update({x: conf_cell[x]}) + return config + + +class ResidualLayer(Layer): + r"""Residual Layer as defined by `DimNetPP `__ .""" + + def __init__(self, units, + use_bias=True, + activation='kgcnn>swish', + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + kernel_initializer='glorot_uniform', + bias_initializer='zeros', + **kwargs): + """Initialize layer. + + Args: + units: Dimension of the kernel. + use_bias (bool, optional): Use bias. Defaults to True. + activation (str): Activation function. Default is "kgcnn>swish". + kernel_regularizer: Kernel regularization. Default is None. + bias_regularizer: Bias regularization. Default is None. + activity_regularizer: Activity regularization. Default is None. + kernel_constraint: Kernel constrains. Default is None. + bias_constraint: Bias constrains. Default is None. + kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'. + bias_initializer: Initializer for bias. Default is 'zeros'. 
+ """ + super(ResidualLayer, self).__init__(**kwargs) + dense_args = { + "units": units, "activation": activation, "use_bias": use_bias, + "kernel_regularizer": kernel_regularizer, "activity_regularizer": activity_regularizer, + "bias_regularizer": bias_regularizer, "kernel_constraint": kernel_constraint, + "bias_constraint": bias_constraint, "kernel_initializer": kernel_initializer, + "bias_initializer": bias_initializer + } + + self.dense_1 = Dense(**dense_args) + self.dense_2 = Dense(**dense_args) + self.add_end = Add() + + def build(self, input_shape): + """Build layer.""" + super(ResidualLayer, self).build(input_shape) + + def call(self, inputs, **kwargs): + """Forward pass. + + Args: + inputs (Tensor): Node or edge embedding of shape ([N], F) + + Returns: + Tensor: Node or edge embedding of shape ([N], F) + """ + x = self.dense_1(inputs, **kwargs) + x = self.dense_2(x, **kwargs) + x = self.add_end([inputs, x], **kwargs) + return x + + def get_config(self): + config = super(ResidualLayer, self).get_config() + conf_dense = self.dense_1.get_config() + for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint", + "bias_constraint", "kernel_initializer", "bias_initializer", "activation", "use_bias", "units"]: + if x in conf_dense.keys(): + config.update({x: conf_dense[x]}) + return config diff --git a/kgcnn/literature/AttentiveFP/_make.py b/kgcnn/literature/AttentiveFP/_make.py new file mode 100644 index 00000000..84e62048 --- /dev/null +++ b/kgcnn/literature/AttentiveFP/_make.py @@ -0,0 +1,143 @@ +import keras_core as ks +from kgcnn.layers.scale import get as get_scaler +from kgcnn.models.utils import update_model_kwargs +from kgcnn.models.casting import template_cast_output, template_cast_list_input +from keras_core.backend import backend as backend_to_use +from kgcnn.layers.modules import Input +from ._model import model_disjoint + +# Keep track of model version from commit date in literature. +# To be updated if model is changed in a significant way. +__model_version__ = "2023.11.15" + +# Supported backends +__kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"] + +if backend_to_use() not in __kgcnn_model_backend_supported__: + raise NotImplementedError("Backend '%s' for model 'AttentiveFP' is not supported." % backend_to_use()) + +# Implementation of AttentiveFP in `keras` from paper: +# Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism +# Zhaoping Xiong, Dingyan Wang, Xiaohong Liu, Feisheng Zhong, Xiaozhe Wan, Xutong Li, Zhaojun Li, +# Xiaomin Luo, Kaixian Chen, Hualiang Jiang*, and Mingyue Zheng* +# Cite this: J. Med. Chem. 
2020, 63, 16, 8749–8760 +# Publication Date:August 13, 2019 +# https://doi.org/10.1021/acs.jmedchem.9b00959 + + +model_default = { + "name": "AttentiveFP", + "inputs": [ + {"shape": (None,), "name": "node_number", "dtype": "int64"}, + {"shape": (None,), "name": "edge_number", "dtype": "int64"}, + {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_edges", "dtype": "int64"} + ], + "input_tensor_type": "padded", + "cast_disjoint_kwargs": {}, + "input_embedding": None, + "input_node_embedding": {"input_dim": 95, "output_dim": 64}, + "input_edge_embedding": {"input_dim": 5, "output_dim": 64}, + "attention_args": {"units": 32}, + "depthmol": 2, + "depthato": 2, + "dropout": 0.1, + "verbose": 10, + "output_embedding": "graph", "output_to_tensor": True, + "output_tensor_type": "padded", + "output_mlp": {"use_bias": [True, True, False], "units": [25, 10, 1], + "activation": ["relu", "relu", "sigmoid"]} +} + + +@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"]) +def make_model(inputs: list = None, + cast_disjoint_kwargs: dict = None, + input_tensor_type: str = None, + input_node_embedding: dict = None, + input_edge_embedding: dict = None, + input_embedding: dict = None, + depthmol: int = None, + depthato: int = None, + dropout: float = None, + attention_args: dict = None, + name: str = None, + verbose: int = None, + output_embedding: str = None, + output_to_tensor: bool = None, + output_tensor_type: str = None, + output_scaling: dict = None, + output_mlp: dict = None + ): + r"""Make `AttentiveFP `_ graph network via functional API. + Default parameters can be found in :obj:`kgcnn.literature.AttentiveFP.model_default`. + + Args: + inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition. + cast_disjoint_kwargs (dict): Dictionary of arguments for :obj:`CastBatchedIndicesToDisjoint` . + input_tensor_type (str): Input type of graph tensor. Default is "padded". + input_embedding (dict): Deprecated in favour of `input_node_embedding` etc. + input_node_embedding (dict): Dictionary of arguments for nodes unpacked in :obj:`Embedding` layers. + input_edge_embedding (dict): Dictionary of arguments for edge unpacked in :obj:`Embedding` layers. + depthato (int): Number of graph embedding units or depth of the network. + depthmol (int): Number of graph embedding units or depth of the graph embedding. + dropout (float): Dropout to use. + attention_args (dict): Dictionary of layer arguments unpacked in :obj:`AttentiveHeadFP` layer. Units parameter + is also used in GRU-update and :obj:`PoolingNodesAttentive`. + name (str): Name of the model. + verbose (int): Level of print output. + output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". + output_to_tensor (bool): Deprecated in favour of `output_tensor_type` . + output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". + output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. + Defines number of model outputs and activation. + output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. 
+ + Returns: + :obj:`keras.models.Model` + """ + # Make input + model_inputs = [Input(**x) for x in inputs] + + di_inputs = template_cast_list_input( + model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs) + + n, ed, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = di_inputs + + # Wrapping disjoint model. + out = model_disjoint( + [n, ed, disjoint_indices, batch_id_node, count_nodes], + use_node_embedding=len(inputs[0]['shape']) < 2, + use_edge_embedding=len(inputs[1]['shape']) < 2, + input_node_embedding=input_node_embedding, + input_edge_embedding=input_edge_embedding, + depthmol=depthmol, + depthato=depthato, + dropout=dropout, + attention_args=attention_args, + output_embedding=output_embedding, + output_mlp=output_mlp + ) + + if output_scaling is not None: + scaler = get_scaler(output_scaling["name"])(**output_scaling) + out = scaler(out) + + # Output embedding choice + out = template_cast_output( + [out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges], + output_embedding=output_embedding, output_tensor_type=output_tensor_type, + input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, + ) + + model = ks.models.Model(inputs=model_inputs, outputs=out, name=name) + model.__kgcnn_model_version__ = __model_version__ + + if output_scaling is not None: + def set_scale(*args, **kwargs): + scaler.set_scale(*args, **kwargs) + + setattr(model, "set_scale", set_scale) + + return model diff --git a/kgcnn/literature/AttentiveFP/_model.py b/kgcnn/literature/AttentiveFP/_model.py new file mode 100644 index 00000000..fbbc5f41 --- /dev/null +++ b/kgcnn/literature/AttentiveFP/_model.py @@ -0,0 +1,50 @@ +import keras_core as ks +from kgcnn.layers.attention import AttentiveHeadFP +from kgcnn.layers.mlp import MLP, GraphMLP +from kgcnn.layers.modules import Embedding +from kgcnn.layers.update import GRUUpdate +from kgcnn.layers.pooling import PoolingNodesAttentive + + +def model_disjoint( + inputs, + use_node_embedding: bool = None, + use_edge_embedding: bool = None, + input_node_embedding: dict = None, + input_edge_embedding: dict = None, + depthmol: int = None, + depthato: int = None, + dropout: float = None, + attention_args: dict = None, + output_embedding: str = None, + output_mlp: dict = None +): + n, ed, edi, batch_nodes, count_nodes = inputs + + # Embedding, if no feature dimension + if use_node_embedding: + n = Embedding(**input_node_embedding)(n) + if use_edge_embedding: + ed = Embedding(**input_edge_embedding)(ed) + + # Model + nk = ks.layers.Dense(units=attention_args['units'])(n) + ck = AttentiveHeadFP(use_edge_features=True, **attention_args)([nk, ed, edi]) + nk = GRUUpdate(units=attention_args['units'])([nk, ck]) + + for i in range(1, depthato): + ck = AttentiveHeadFP(**attention_args)([nk, ed, edi]) + nk = GRUUpdate(units=attention_args['units'])([nk, ck]) + nk = ks.layers.Dropout(rate=dropout)(nk) + n = nk + + # Output embedding choice + if output_embedding == 'graph': + out = PoolingNodesAttentive(units=attention_args['units'], depth=depthmol)([count_nodes, n, batch_nodes]) + out = MLP(**output_mlp)(out) + elif output_embedding == 'node': + out = GraphMLP(**output_mlp)(n) + else: + raise ValueError("Unsupported graph embedding for mode `AttentiveFP`") + + return out \ No newline at end of file diff --git a/kgcnn/literature/DMPNN/_make.py b/kgcnn/literature/DMPNN/_make.py index 8b2bc483..ccf45108 100644 --- a/kgcnn/literature/DMPNN/_make.py +++ 
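+    # The first attention head above consumes the edge features; the remaining depthato - 1 iterations below
+    # refine the node states only, each followed by a GRU update and dropout.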
+    for i in range(1, depthato):
+        ck = AttentiveHeadFP(**attention_args)([nk, ed, edi])
+        nk = GRUUpdate(units=attention_args['units'])([nk, ck])
+        nk = ks.layers.Dropout(rate=dropout)(nk)
+    n = nk
+
+    # Output embedding choice
+    if output_embedding == 'graph':
+        out = PoolingNodesAttentive(units=attention_args['units'], depth=depthmol)([count_nodes, n, batch_nodes])
+        out = MLP(**output_mlp)(out)
+    elif output_embedding == 'node':
+        out = GraphMLP(**output_mlp)(n)
+    else:
+        raise ValueError("Unsupported graph embedding for mode `AttentiveFP`")
+
+    return out
\ No newline at end of file
diff --git a/kgcnn/literature/DMPNN/_make.py b/kgcnn/literature/DMPNN/_make.py
index 8b2bc483..ccf45108 100644
--- a/kgcnn/literature/DMPNN/_make.py
+++ b/kgcnn/literature/DMPNN/_make.py
@@ -1,8 +1,4 @@
 import keras_core as ks
-from kgcnn.layers.casting import (CastBatchedIndicesToDisjoint, CastBatchedAttributesToDisjoint,
-                                  CastDisjointToBatchedGraphState, CastDisjointToBatchedAttributes,
-                                  CastBatchedGraphStateToDisjoint, CastRaggedAttributesToDisjoint,
-                                  CastRaggedIndicesToDisjoint, CastDisjointToRaggedAttributes)
 from kgcnn.layers.scale import get as get_scaler
 from kgcnn.models.utils import update_model_kwargs
 from kgcnn.models.casting import template_cast_output, template_cast_list_input
diff --git a/kgcnn/literature/GAT/_make.py b/kgcnn/literature/GAT/_make.py
index da354fc3..1e56fb3e 100644
--- a/kgcnn/literature/GAT/_make.py
+++ b/kgcnn/literature/GAT/_make.py
@@ -102,11 +102,11 @@ def make_model(inputs: list = None,
         name (str): Name of the model.
         verbose (int): Level of print output.
         output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph".
-        output_to_tensor (bool): Whether to cast model output to :obj:`tf.Tensor`.
+        output_to_tensor (bool): Deprecated in favour of `output_tensor_type` .
         output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded".
         output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block.
             Defines number of model outputs and activation.
-        output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None.
+        output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers.
 
     Returns:
         :obj:`ks.models.Model`
diff --git a/kgcnn/literature/PAiNN/_layers.py b/kgcnn/literature/PAiNN/_layers.py
index 881b310f..e042693b 100644
--- a/kgcnn/literature/PAiNN/_layers.py
+++ b/kgcnn/literature/PAiNN/_layers.py
@@ -221,7 +221,7 @@ def get_config(self):
         config_dense = self.lay_dense1.get_config()
         for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint",
                   "bias_constraint", "kernel_initializer", "bias_initializer", "activation", "use_bias"]:
-            if x in config_dense:
+            if x in config_dense.keys():
                 config.update({x: config_dense[x]})
         return config
 
diff --git a/kgcnn/literature/PAiNN/_make.py b/kgcnn/literature/PAiNN/_make.py
index df40baed..f72a7a9e 100644
--- a/kgcnn/literature/PAiNN/_make.py
+++ b/kgcnn/literature/PAiNN/_make.py
@@ -1,5 +1,4 @@
 import keras_core as ks
-from kgcnn.layers.casting import (CastBatchedIndicesToDisjoint, CastBatchedAttributesToDisjoint)
 from kgcnn.layers.modules import Input
 from kgcnn.models.utils import update_model_kwargs
 from keras_core.backend import backend as backend_to_use
@@ -50,7 +49,7 @@
 }
 
 
-@update_model_kwargs(model_default, update_recursive=0)
+@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"])
 def make_model(inputs: list = None,
                input_tensor_type: str = None,
                cast_disjoint_kwargs: dict = None,
@@ -195,7 +194,7 @@ def set_scale(*args, **kwargs):
 }
 
 
-@update_model_kwargs(model_crystal_default)
+@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"])
 def make_crystal_model(inputs: list = None,
                        input_tensor_type: str = None,
                        input_embedding: dict = None,
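
Usage sketch (not part of the patch): a minimal example of how the AttentiveFP model added above could be built and called, assuming the `kgcnn.literature.AttentiveFP` package re-exports `make_model` from `._make`; the input arrays are hypothetical placeholders shaped to match the padded `model_default` inputs.

    import numpy as np
    from kgcnn.literature.AttentiveFP import make_model  # assumes make_model is exported by the package

    # Build the model with the padded-input defaults introduced in _make.py of this patch.
    model = make_model(
        inputs=[
            {"shape": (None,), "name": "node_number", "dtype": "int64"},
            {"shape": (None,), "name": "edge_number", "dtype": "int64"},
            {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"},
            {"shape": (), "name": "total_nodes", "dtype": "int64"},
            {"shape": (), "name": "total_edges", "dtype": "int64"}
        ],
        attention_args={"units": 32}, depthato=2, depthmol=2, dropout=0.1,
        output_embedding="graph",
        output_mlp={"use_bias": [True, True, False], "units": [25, 10, 1],
                    "activation": ["relu", "relu", "sigmoid"]}
    )

    # Hypothetical padded batch of two graphs (placeholder values, not real molecules).
    node_number = np.array([[6, 8, 1, 0], [6, 6, 7, 8]], dtype="int64")
    edge_number = np.array([[1, 1, 0, 0], [1, 2, 1, 1]], dtype="int64")
    edge_indices = np.array([[[0, 1], [1, 0], [0, 0], [0, 0]],
                             [[0, 1], [1, 2], [2, 3], [3, 0]]], dtype="int64")
    total_nodes = np.array([3, 4], dtype="int64")
    total_edges = np.array([2, 4], dtype="int64")

    out = model.predict([node_number, edge_number, edge_indices, total_nodes, total_edges])
    print(out.shape)  # expected (2, 1): one sigmoid output per graph with the default output_mlp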