From e2237230195a893a0d24472c4c03503d1f635b9f Mon Sep 17 00:00:00 2001 From: PatReis Date: Sun, 10 Dec 2023 18:58:34 +0100 Subject: [PATCH] Update and test models. --- kgcnn/layers/casting.py | 3 + kgcnn/layers/geom.py | 91 ++++++- kgcnn/layers/modules.py | 18 ++ kgcnn/literature/DimeNetPP/_layers.py | 85 ------- kgcnn/literature/DimeNetPP/_model.py | 5 +- kgcnn/literature/INorp/__init__.py | 7 + kgcnn/literature/INorp/_make.py | 3 +- kgcnn/literature/INorp/_model.py | 2 +- kgcnn/literature/MEGAN/__init__.py | 10 + kgcnn/literature/MEGAN/_make.py | 3 + kgcnn/literature/MEGAN/_model.py | 12 +- kgcnn/literature/MXMNet/__init__.py | 7 + kgcnn/literature/MXMNet/_layers.py | 16 +- kgcnn/literature/MXMNet/_make.py | 30 +-- kgcnn/literature/MXMNet/_model.py | 7 +- kgcnn/literature/MoGAT/__init__.py | 6 + kgcnn/literature/MoGAT/_make.py | 10 +- kgcnn/literature/MoGAT/_model.py | 15 +- kgcnn/literature/rGIN/__init__.py | 6 + kgcnn/literature/rGIN/_layers.py | 4 +- kgcnn/literature/rGIN/_make.py | 10 + training/hyper/hyper_esol.py | 350 ++++++++++++++++++++++++++ 22 files changed, 571 insertions(+), 129 deletions(-) diff --git a/kgcnn/layers/casting.py b/kgcnn/layers/casting.py index 6d517cbb..94e0beb3 100644 --- a/kgcnn/layers/casting.py +++ b/kgcnn/layers/casting.py @@ -515,6 +515,9 @@ def compute_output_shape(self, input_shape): return tuple([input_shape[0] + 1] + list(input_shape[1:])) return input_shape + def compute_output_spec(self, input_spec): + return ks.KerasTensor(self.compute_output_shape(input_spec.shape), input_spec.dtype) + def call(self, inputs: list, **kwargs): r"""Changes graph tensor from disjoint representation. diff --git a/kgcnn/layers/geom.py b/kgcnn/layers/geom.py index 12deb9bb..b4c03cca 100644 --- a/kgcnn/layers/geom.py +++ b/kgcnn/layers/geom.py @@ -2,9 +2,11 @@ import numpy as np from typing import Union import keras as ks -from keras import ops +from keras import ops, Layer from keras.layers import Layer, Subtract, Multiply, Add, Subtract -from kgcnn.layers.gather import GatherNodes, GatherState +from kgcnn.layers.gather import GatherNodes, GatherState, GatherNodesOutgoing +from kgcnn.layers.polynom import spherical_bessel_jn_zeros, spherical_bessel_jn_normalization_prefactor, \ + tf_spherical_bessel_jn, tf_spherical_harmonics_yl from kgcnn.ops.axis import get_positive_axis from kgcnn.ops.core import cross as kgcnn_cross @@ -493,7 +495,7 @@ def call(self, inputs, **kwargs): Tensor: Edge angles between edges that match the indices. Shape is `([K], 1)` . """ v1, v2 = self.layer_gather_vectors(inputs) - if self.vector_scale: + if self.vector_scale is not None: v1, v2 = [ x * ops.cast(self._const_vec_scale[i], dtype=x.dtype) for i, x in enumerate([v1, v2]) ] @@ -1100,3 +1102,86 @@ def get_config(self): config = super(RealToFracCoordinates, self).get_config() config.update({"is_inverse_lattice_matrix": self.is_inverse_lattice_matrix}) return config + + +class SphericalBasisLayer(Layer): + r"""Expand a distance into a Bessel Basis with :math:`l=m=0`, according to + `Klicpera et al. 2020 `__ . + + Args: + num_spherical (int): Number of spherical basis functions + num_radial (int): Number of radial basis functions + cutoff (float): Cutoff distance c + envelope_exponent (int): Degree of the envelope to smoothen at cutoff. Default is 5. + + """ + + def __init__(self, num_spherical, + num_radial, + cutoff, + envelope_exponent=5, + **kwargs): + super(SphericalBasisLayer, self).__init__(**kwargs) + + assert num_radial <= 64 + self.num_radial = int(num_radial) + self.num_spherical = num_spherical + self.cutoff = cutoff + self.inv_cutoff = ops.convert_to_tensor(1.0 / cutoff, dtype=self.dtype) + self.envelope_exponent = envelope_exponent + + # retrieve formulas + self.bessel_n_zeros = spherical_bessel_jn_zeros(num_spherical, num_radial) + self.bessel_norm = spherical_bessel_jn_normalization_prefactor(num_spherical, num_radial) + + self.layer_gather_out = GatherNodesOutgoing() + + def envelope(self, inputs): + p = self.envelope_exponent + 1 + a = -(p + 1) * (p + 2) / 2 + b = p * (p + 2) + c = -p * (p + 1) / 2 + env_val = 1 / inputs + a * inputs ** (p - 1) + b * inputs ** p + c * inputs ** (p + 1) + return ops.where(inputs < 1, env_val, ops.zeros_like(inputs)) + + def call(self, inputs, **kwargs): + """Forward pass. + + Args: + inputs: [distance, angles, angle_index] + + - distance (Tensor): Edge distance of shape ([M], 1) + - angles (Tensor): Angle list of shape ([K], 1) + - angle_index (Tensor): Angle indices referring to edges of shape (2, [K]) + + Returns: + Tensor: Expanded angle/distance basis. Shape is ([K], #Radial * #Spherical) + """ + edge, angles, angle_index = inputs + + d = edge + d_scaled = d[:, 0] * self.inv_cutoff + rbf = [] + for n in range(self.num_spherical): + for k in range(self.num_radial): + rbf += [self.bessel_norm[n, k] * tf_spherical_bessel_jn(d_scaled * self.bessel_n_zeros[n][k], n)] + rbf = ops.stack(rbf, axis=1) + + d_cutoff = self.envelope(d_scaled) + rbf_env = d_cutoff[:, None] * rbf + rbf_env = self.layer_gather_out([rbf_env, angle_index], **kwargs) + # rbf_env = tf.gather(rbf_env, id_expand_kj[:, 1]) + + cbf = [tf_spherical_harmonics_yl(angles[:, 0], n) for n in range(self.num_spherical)] + cbf = ops.stack(cbf, axis=1) + cbf = ops.repeat(cbf, self.num_radial, axis=1) + out = rbf_env * cbf + + return out + + def get_config(self): + """Update config.""" + config = super(SphericalBasisLayer, self).get_config() + config.update({"num_radial": self.num_radial, "cutoff": self.cutoff, + "envelope_exponent": self.envelope_exponent, "num_spherical": self.num_spherical}) + return config \ No newline at end of file diff --git a/kgcnn/layers/modules.py b/kgcnn/layers/modules.py index bff9febd..dac2c170 100644 --- a/kgcnn/layers/modules.py +++ b/kgcnn/layers/modules.py @@ -62,6 +62,24 @@ def get_config(self): return config +class SqueezeDims(ks.layers.Layer): + + def __init__(self, axis, **kwargs): + super(SqueezeDims, self).__init__(**kwargs) + self.axis = axis + + def build(self, input_shape): + self.built = True + + def call(self, inputs): + return ops.squeeze(inputs, axis=self.axis) + + def get_config(self): + config = super(SqueezeDims, self).get_config() + config.update({"axis": self.axis}) + return config + + def Input( shape=None, batch_size=None, diff --git a/kgcnn/literature/DimeNetPP/_layers.py b/kgcnn/literature/DimeNetPP/_layers.py index d4f82561..bca30254 100644 --- a/kgcnn/literature/DimeNetPP/_layers.py +++ b/kgcnn/literature/DimeNetPP/_layers.py @@ -5,8 +5,6 @@ from kgcnn.layers.gather import GatherNodesOutgoing from kgcnn.layers.mlp import GraphMLP from kgcnn.layers.update import ResidualLayer -from kgcnn.layers.polynom import spherical_bessel_jn_zeros, spherical_bessel_jn_normalization_prefactor, \ - tf_spherical_bessel_jn, tf_spherical_harmonics_yl from kgcnn.initializers.initializers import GlorotOrthogonal, HeOrthogonal from kgcnn.ops.activ import swish @@ -311,86 +309,3 @@ def get_config(self): "embeddings_constraint": ks.constraints.serialize(self.embeddings_constraint) }) return config - - -class SphericalBasisLayer(Layer): - r"""Expand a distance into a Bessel Basis with :math:`l=m=0`, according to - `Klicpera et al. 2020 `__ . - - Args: - num_spherical (int): Number of spherical basis functions - num_radial (int): Number of radial basis functions - cutoff (float): Cutoff distance c - envelope_exponent (int): Degree of the envelope to smoothen at cutoff. Default is 5. - - """ - - def __init__(self, num_spherical, - num_radial, - cutoff, - envelope_exponent=5, - **kwargs): - super(SphericalBasisLayer, self).__init__(**kwargs) - - assert num_radial <= 64 - self.num_radial = int(num_radial) - self.num_spherical = num_spherical - self.cutoff = cutoff - self.inv_cutoff = ops.convert_to_tensor(1.0 / cutoff, dtype=self.dtype) - self.envelope_exponent = envelope_exponent - - # retrieve formulas - self.bessel_n_zeros = spherical_bessel_jn_zeros(num_spherical, num_radial) - self.bessel_norm = spherical_bessel_jn_normalization_prefactor(num_spherical, num_radial) - - self.layer_gather_out = GatherNodesOutgoing() - - def envelope(self, inputs): - p = self.envelope_exponent + 1 - a = -(p + 1) * (p + 2) / 2 - b = p * (p + 2) - c = -p * (p + 1) / 2 - env_val = 1 / inputs + a * inputs ** (p - 1) + b * inputs ** p + c * inputs ** (p + 1) - return ops.where(inputs < 1, env_val, ops.zeros_like(inputs)) - - def call(self, inputs, **kwargs): - """Forward pass. - - Args: - inputs: [distance, angles, angle_index] - - - distance (Tensor): Edge distance of shape ([M], 1) - - angles (Tensor): Angle list of shape ([K], 1) - - angle_index (Tensor): Angle indices referring to edges of shape (2, [K]) - - Returns: - Tensor: Expanded angle/distance basis. Shape is ([K], #Radial * #Spherical) - """ - edge, angles, angle_index = inputs - - d = edge - d_scaled = d[:, 0] * self.inv_cutoff - rbf = [] - for n in range(self.num_spherical): - for k in range(self.num_radial): - rbf += [self.bessel_norm[n, k] * tf_spherical_bessel_jn(d_scaled * self.bessel_n_zeros[n][k], n)] - rbf = ops.stack(rbf, axis=1) - - d_cutoff = self.envelope(d_scaled) - rbf_env = d_cutoff[:, None] * rbf - rbf_env = self.layer_gather_out([rbf_env, angle_index], **kwargs) - # rbf_env = tf.gather(rbf_env, id_expand_kj[:, 1]) - - cbf = [tf_spherical_harmonics_yl(angles[:, 0], n) for n in range(self.num_spherical)] - cbf = ops.stack(cbf, axis=1) - cbf = ops.repeat(cbf, self.num_radial, axis=1) - out = rbf_env * cbf - - return out - - def get_config(self): - """Update config.""" - config = super(SphericalBasisLayer, self).get_config() - config.update({"num_radial": self.num_radial, "cutoff": self.cutoff, - "envelope_exponent": self.envelope_exponent, "num_spherical": self.num_spherical}) - return config diff --git a/kgcnn/literature/DimeNetPP/_model.py b/kgcnn/literature/DimeNetPP/_model.py index c678ee91..e4969f8a 100644 --- a/kgcnn/literature/DimeNetPP/_model.py +++ b/kgcnn/literature/DimeNetPP/_model.py @@ -1,9 +1,10 @@ from keras.layers import Add, Subtract, Concatenate, Dense -from kgcnn.layers.geom import NodePosition, NodeDistanceEuclidean, BesselBasisLayer, EdgeAngle, ShiftPeriodicLattice +from kgcnn.layers.geom import NodePosition, NodeDistanceEuclidean, BesselBasisLayer, EdgeAngle, ShiftPeriodicLattice, \ + SphericalBasisLayer from kgcnn.layers.gather import GatherNodes from kgcnn.layers.pooling import PoolingNodes from kgcnn.layers.mlp import MLP -from ._layers import DimNetInteractionPPBlock, EmbeddingDimeBlock, SphericalBasisLayer, DimNetOutputBlock +from ._layers import DimNetInteractionPPBlock, EmbeddingDimeBlock, DimNetOutputBlock def model_disjoint( diff --git a/kgcnn/literature/INorp/__init__.py b/kgcnn/literature/INorp/__init__.py index e69de29b..25427c88 100644 --- a/kgcnn/literature/INorp/__init__.py +++ b/kgcnn/literature/INorp/__init__.py @@ -0,0 +1,7 @@ +from ._make import make_model, model_default + + +__all__ = [ + "make_model", + "model_default" +] diff --git a/kgcnn/literature/INorp/_make.py b/kgcnn/literature/INorp/_make.py index 3539991c..92e194d6 100644 --- a/kgcnn/literature/INorp/_make.py +++ b/kgcnn/literature/INorp/_make.py @@ -33,6 +33,7 @@ {"shape": (), "name": "total_nodes", "dtype": "int64"}, {"shape": (), "name": "total_edges", "dtype": "int64"} ], + 'input_tensor_type': "padded", 'input_embedding': None, "cast_disjoint_kwargs": {}, "input_node_embedding": {"input_dim": 95, "output_dim": 64}, @@ -139,7 +140,7 @@ def make_model(inputs: list = None, # Wrapping disjoint model. out = model_disjoint( - [n, ed, disjoint_indices, gs, batch_id_node, count_nodes], + [n, ed, disjoint_indices, gs, batch_id_node, batch_id_edge, count_nodes, count_edges], use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, use_edge_embedding=("int" in inputs[1]['dtype']) if input_edge_embedding is not None else False, use_graph_embedding=("int" in inputs[3]['dtype']) if input_graph_embedding is not None else False, diff --git a/kgcnn/literature/INorp/_model.py b/kgcnn/literature/INorp/_model.py index bde6df4c..ab93bf07 100644 --- a/kgcnn/literature/INorp/_model.py +++ b/kgcnn/literature/INorp/_model.py @@ -35,7 +35,7 @@ def model_disjoint( uenv = Embedding(**input_graph_embedding)(uenv) # Model - ev = GatherState(**gather_args)([uenv, n]) + ev = GatherState(**gather_args)([uenv, batch_id_node]) # n-Layer Step for i in range(0, depth): # upd = GatherNodes()([n,edi]) diff --git a/kgcnn/literature/MEGAN/__init__.py b/kgcnn/literature/MEGAN/__init__.py index e69de29b..d0a49671 100644 --- a/kgcnn/literature/MEGAN/__init__.py +++ b/kgcnn/literature/MEGAN/__init__.py @@ -0,0 +1,10 @@ +from ._model import MEGAN, shifted_sigmoid, ExplanationSparsityRegularization +from ._make import make_model + + +__all__ = [ + "make_model", + "MEGAN", + "ExplanationSparsityRegularization", + "shifted_sigmoid" +] diff --git a/kgcnn/literature/MEGAN/_make.py b/kgcnn/literature/MEGAN/_make.py index 15307975..be8ea952 100644 --- a/kgcnn/literature/MEGAN/_make.py +++ b/kgcnn/literature/MEGAN/_make.py @@ -24,6 +24,7 @@ model_default = { + "name": "MEGAN", "inputs": [ {'shape': (None, 128), 'name': "node_attributes", 'dtype': 'float32'}, {'shape': (None, 64), 'name': "edge_attributes", 'dtype': 'float32'}, @@ -32,6 +33,7 @@ {"shape": (), "name": "total_nodes", "dtype": "int64"}, {"shape": (), "name": "total_edges", "dtype": "int64"} ], + "input_tensor_type": "padded", 'input_embedding': None, "cast_disjoint_kwargs": {}, "units": [128], @@ -72,6 +74,7 @@ def make_model(inputs: list = None, use_bias: bool = None, dropout_rate: float = None, use_edge_features: bool = None, + input_embedding: dict = None, # deprecated input_node_embedding: dict = None, # node/edge importance related arguments importance_units: list = None, diff --git a/kgcnn/literature/MEGAN/_model.py b/kgcnn/literature/MEGAN/_model.py index 595b6486..f828eb8a 100644 --- a/kgcnn/literature/MEGAN/_model.py +++ b/kgcnn/literature/MEGAN/_model.py @@ -99,7 +99,7 @@ def __init__(self, of the model. If this is True, the output of the model will be a 3-tuple: (output, node importances, edge importances), otherwise it is just the output itself """ - super(MEGAN, self).__init__(self, **kwargs) + super().__init__() self.units = units self.activation = activation self.use_bias = use_bias @@ -234,7 +234,7 @@ def call(self, node_input, edge_input, edge_index_input, out_true, batch_id_node, count_nodes = inputs - if self.input_embedding: + if self.input_node_embedding: node_input = self.embedding_nodes(node_input, training=training) # First of all we apply all the graph convolutional / attention layers. Each of those layers outputs # the attention logits alpha additional to the node embeddings. We collect all the attention logits @@ -278,8 +278,8 @@ def call(self, # the weights! We concatenate all the individual results in the end. outs = [] for k in range(self.importance_channels): - node_importance_slice = ops.expand_dims(node_importances[:, :, k], axis=-1) - out = self.lay_pool_out(x * node_importance_slice) + node_importance_slice = ops.expand_dims(node_importances[:, k], axis=-1) + out = self.lay_pool_out([count_nodes, x * node_importance_slice, batch_id_node]) outs.append(out) @@ -308,8 +308,8 @@ def call(self, # concatenate into an output vector with K dimensions. outs = [] for k in range(self.importance_channels): - node_importances_slice = ops.expand_dims(ni_pred[:, :, k], axis=-1) - out = self.lay_pool_out(node_importances_slice) + node_importances_slice = ops.expand_dims(ni_pred[:, k], axis=-1) + out = self.lay_pool_out([count_nodes, node_importances_slice, batch_id_node]) outs.append(out) diff --git a/kgcnn/literature/MXMNet/__init__.py b/kgcnn/literature/MXMNet/__init__.py index e69de29b..25427c88 100644 --- a/kgcnn/literature/MXMNet/__init__.py +++ b/kgcnn/literature/MXMNet/__init__.py @@ -0,0 +1,7 @@ +from ._make import make_model, model_default + + +__all__ = [ + "make_model", + "model_default" +] diff --git a/kgcnn/literature/MXMNet/_layers.py b/kgcnn/literature/MXMNet/_layers.py index 8246d378..e424a1d7 100644 --- a/kgcnn/literature/MXMNet/_layers.py +++ b/kgcnn/literature/MXMNet/_layers.py @@ -1,6 +1,6 @@ import keras as ks from keras.layers import Layer, Add, Multiply, Concatenate, Dense -from kgcnn.literature.DimeNetPP._layers import ResidualLayer +from kgcnn.layers.update import ResidualLayer from kgcnn.layers.mlp import GraphMLP from kgcnn.layers.gather import GatherNodes, GatherNodesOutgoing from kgcnn.layers.aggr import AggregateLocalEdges as PoolingLocalMessages @@ -8,10 +8,11 @@ class MXMGlobalMP(Layer): - def __init__(self, units: int = 64, **kwargs): + def __init__(self, units: int = 64, pooling_method="mean", **kwargs): """Initialize layer.""" super(MXMGlobalMP, self).__init__(**kwargs) self.dim = units + self.pooling_method = pooling_method self.h_mlp = GraphMLP(self.dim, activation="swish") self.res1 = ResidualLayer(self.dim) self.res2 = ResidualLayer(self.dim) @@ -23,8 +24,8 @@ def __init__(self, units: int = 64, **kwargs): self.linear = Dense(self.dim, use_bias=False, activation="linear") self.gather = GatherNodes(split_indices=[0, 1], concat_axis=None) - self.pool = PoolingLocalMessages() - self.cat = Concatenate() + self.pool = PoolingLocalMessages(pooling_method=pooling_method) + self.cat = Concatenate(axis=-1) self.multiply_edge = Multiply() self.add = Add() @@ -36,7 +37,7 @@ def propagate(self, edge_index, x, edge_attr, **kwargs): x_i, x_j = self.gather([x, edge_index]) # Prepare message. - x_edge = self.cat([x_i, x_j, edge_attr], axis=-1) + x_edge = self.cat([x_i, x_j, edge_attr], **kwargs) x_edge = self.x_edge_mlp(x_edge, **kwargs) edge_attr_lin = self.linear(edge_attr, **kwargs) x_edge = self.multiply_edge([edge_attr_lin, x_edge]) @@ -87,7 +88,7 @@ def call(self, inputs, **kwargs): def get_config(self): config = super(MXMGlobalMP, self).get_config() - config.update({"units": self.dim}) + config.update({"units": self.dim, "pooling_method": self.pooling_method}) return config @@ -135,6 +136,9 @@ def __init__(self, units: int = 64, output_units: int = 1, activation: str = "sw self.pool_h = PoolingLocalMessages(pooling_method=pooling_method) self.add_mji_1 = Add() self.add_mji_2 = Add() + + def build(self, input_shape): + super(MXMLocalMP, self).build(input_shape) def call(self, inputs, **kwargs): r"""Forward pass. diff --git a/kgcnn/literature/MXMNet/_make.py b/kgcnn/literature/MXMNet/_make.py index 6af215c3..3f1a113d 100644 --- a/kgcnn/literature/MXMNet/_make.py +++ b/kgcnn/literature/MXMNet/_make.py @@ -161,21 +161,21 @@ def make_model(inputs: list = None, dj, use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, use_edge_embedding=("int" in inputs[2]['dtype']) if input_edge_embedding is not None else False, - input_node_embedding=None, - input_edge_embedding=None, - bessel_basis_local=None, - spherical_basis_local=None, - bessel_basis_global=None, - use_edge_attributes=None, - mlp_rbf_kwargs=None, - mlp_sbf_kwargs=None, - depth=None, - global_mp_kwargs=None, - local_mp_kwargs=None, - node_pooling_args=None, - output_embedding=None, - use_output_mlp=None, - output_mlp=None + input_node_embedding=input_node_embedding, + input_edge_embedding=input_edge_embedding, + bessel_basis_local=bessel_basis_local, + spherical_basis_local=spherical_basis_local, + bessel_basis_global=bessel_basis_global, + use_edge_attributes=use_edge_attributes, + mlp_rbf_kwargs=mlp_rbf_kwargs, + mlp_sbf_kwargs=mlp_sbf_kwargs, + depth=depth, + global_mp_kwargs=global_mp_kwargs, + local_mp_kwargs=local_mp_kwargs, + node_pooling_args=node_pooling_args, + output_embedding=output_embedding, + use_output_mlp=use_output_mlp, + output_mlp=output_mlp ) if output_scaling is not None: diff --git a/kgcnn/literature/MXMNet/_model.py b/kgcnn/literature/MXMNet/_model.py index aa5bfd12..d25ee8e1 100644 --- a/kgcnn/literature/MXMNet/_model.py +++ b/kgcnn/literature/MXMNet/_model.py @@ -1,9 +1,10 @@ -from kgcnn.layers.geom import NodePosition, NodeDistanceEuclidean, BesselBasisLayer, EdgeAngle +from kgcnn.layers.geom import NodePosition, NodeDistanceEuclidean, BesselBasisLayer, EdgeAngle, SphericalBasisLayer from keras.layers import Concatenate, Subtract, Add from kgcnn.layers.mlp import MLP, GraphMLP from kgcnn.layers.pooling import PoolingNodes -from kgcnn.literature.DimeNetPP._layers import EmbeddingDimeBlock, SphericalBasisLayer +from kgcnn.literature.DimeNetPP._layers import EmbeddingDimeBlock from ._layers import MXMGlobalMP, MXMLocalMP +from kgcnn.ops.activ import swish def model_disjoint( @@ -36,7 +37,7 @@ def model_disjoint( if use_node_embedding: n = EmbeddingDimeBlock(**input_node_embedding)(n) if use_edge_embedding: - n = EmbeddingDimeBlock(**input_edge_embedding)(ed) + ed = EmbeddingDimeBlock(**input_edge_embedding)(ed) # Calculate distances and spherical and bessel basis for local edges including angles. # For the first version, we restrict ourselves to 2-hop angles. diff --git a/kgcnn/literature/MoGAT/__init__.py b/kgcnn/literature/MoGAT/__init__.py index e69de29b..fc290485 100644 --- a/kgcnn/literature/MoGAT/__init__.py +++ b/kgcnn/literature/MoGAT/__init__.py @@ -0,0 +1,6 @@ +from ._make import make_model, model_default + +__all__ = [ + "make_model", + "model_default" +] diff --git a/kgcnn/literature/MoGAT/_make.py b/kgcnn/literature/MoGAT/_make.py index 87a260d3..ec13f146 100644 --- a/kgcnn/literature/MoGAT/_make.py +++ b/kgcnn/literature/MoGAT/_make.py @@ -15,7 +15,6 @@ if backend_to_use() not in __kgcnn_model_backend_supported__: raise NotImplementedError("Backend '%s' for model 'MoGAT' is not supported." % backend_to_use()) - # Implementation of MoGAT in `keras` from paper: # Multi‑order graph attention network for water solubility prediction and interpretation # Sangho Lee, Hyunwoo Park, Chihyeon Choi, Wonjoon Kim, Ki Kang Kim, Young‑Kyu Han, @@ -132,6 +131,15 @@ def make_model(inputs: list = None, [n, ed, disjoint_indices, batch_id_node, count_nodes], use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, use_edge_embedding=("int" in inputs[1]['dtype']) if input_edge_embedding is not None else False, + input_node_embedding=input_node_embedding, + input_edge_embedding=input_edge_embedding, + attention_args=attention_args, + dropout=dropout, + depthato=depthato, + depthmol=depthmol, + output_embedding=output_embedding, + output_mlp=output_mlp, + pooling_gat_nodes_args=pooling_gat_nodes_args ) if output_scaling is not None: diff --git a/kgcnn/literature/MoGAT/_model.py b/kgcnn/literature/MoGAT/_model.py index 3f2e3644..225f5e05 100644 --- a/kgcnn/literature/MoGAT/_model.py +++ b/kgcnn/literature/MoGAT/_model.py @@ -1,5 +1,6 @@ from kgcnn.layers.update import GRUUpdate -from kgcnn.layers.modules import Embedding +from keras.layers import Flatten, Add +from kgcnn.layers.modules import Embedding, ExpandDims, SqueezeDims from kgcnn.layers.pooling import PoolingNodesAttentive from keras.layers import Dense, Dropout, Concatenate, Attention from kgcnn.layers.mlp import MLP, GraphMLP @@ -17,7 +18,8 @@ def model_disjoint( depthato=None, depthmol=None, output_embedding=None, - output_mlp=None + output_mlp=None, + pooling_gat_nodes_args=None ): # Model implementation with disjoint representation. n, ed, edi, batch_id_node, count_nodes = inputs @@ -44,18 +46,23 @@ def model_disjoint( if output_embedding == 'graph': # we apply a super node to each atomic node representation and concate them - out = Concatenate()([ + out = [ # Tensor output. PoolingNodesAttentive(units=attention_args['units'], depth=depthmol)([count_nodes, ni, batch_id_node]) for - ni in list_emb]) + ni in list_emb + ] + out = [ExpandDims(axis=1)(x) for x in out] + out = Concatenate(axis=1)(out) # we compute the weigthed scaled self-attention of the super nodes at = Attention(dropout=dropout, use_scale=True, score_mode="dot")([out, out]) # we apply the dot product out = at * out + out = Flatten()(out) # in the paper this is only one dense layer to the target ... very simple out = MLP(**output_mlp)(out) elif output_embedding == 'node': + n = Add()(list_emb) out = GraphMLP(**output_mlp)([n, batch_id_node, count_nodes]) else: raise ValueError("Unsupported graph embedding for mode `MoGAT` .") diff --git a/kgcnn/literature/rGIN/__init__.py b/kgcnn/literature/rGIN/__init__.py index e69de29b..e8cd6419 100644 --- a/kgcnn/literature/rGIN/__init__.py +++ b/kgcnn/literature/rGIN/__init__.py @@ -0,0 +1,6 @@ +from ._make import make_model, model_default + +__all__ = [ + "make_model", + "model_default", +] diff --git a/kgcnn/literature/rGIN/_layers.py b/kgcnn/literature/rGIN/_layers.py index f27ea686..2d601758 100644 --- a/kgcnn/literature/rGIN/_layers.py +++ b/kgcnn/literature/rGIN/_layers.py @@ -71,8 +71,8 @@ def call(self, inputs, **kwargs): node, edge_index = inputs node_shape = ops.shape(node) random_values = ops.cast( - ks.random.uniform([node_shape[0], 1], maxval=self.random_range, dtype="int32"), - self.dtype) / self.random_range, + ks.random.uniform([node_shape[0], 1], maxval=self.random_range, dtype="float32"), + self.dtype) / self.random_range node = self.lay_concat([node, random_values]) diff --git a/kgcnn/literature/rGIN/_make.py b/kgcnn/literature/rGIN/_make.py index 2bbc62a3..8b7e9833 100644 --- a/kgcnn/literature/rGIN/_make.py +++ b/kgcnn/literature/rGIN/_make.py @@ -68,7 +68,17 @@ def make_model(inputs: list = None, r"""Make `rGIN `__ graph network via functional API. Default parameters can be found in :obj:`kgcnn.literature.rGIN.model_default` . + **Model inputs**: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, edge_indices, ...]` + with '...' indicating mask or ID tensors following the template below: + %s + + **Model outputs**: + The standard output template: + + %s Args: inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition. diff --git a/training/hyper/hyper_esol.py b/training/hyper/hyper_esol.py index 8ea2ecc0..5839812d 100644 --- a/training/hyper/hyper_esol.py +++ b/training/hyper/hyper_esol.py @@ -1236,4 +1236,354 @@ "kgcnn_version": "4.0.0" } }, + "MoGAT": { + "model": { + "class_name": "make_model", + "module_name": "kgcnn.literature.MoGAT", + "config": { + "name": "MoGAT", + "inputs": [ + {"shape": [None, 41], "name": "node_attributes", "dtype": "float32"}, + {"shape": [None, 11], "name": "edge_attributes", "dtype": "float32"}, + {"shape": [None, 2], "name": "edge_indices", "dtype": "int64"}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_edges", "dtype": "int64"}, + ], + "input_embedding": None, + "input_node_embedding": {"input_dim": 95, "output_dim": 64}, + "input_edge_embedding": {"input_dim": 5, "output_dim": 64}, + "attention_args": {"units": 100}, + "depthato": 2, "depthmol": 2, + "pooling_gat_nodes_args": {'pooling_method': 'mean'}, + "dropout": 0.2, + "verbose": 10, + "output_embedding": "graph", + "output_mlp": {"use_bias": [True], "units": [1], + "activation": ["linear"]} + } + }, + "training": { + "fit": { + "batch_size": 200, "epochs": 200, "validation_freq": 1, "verbose": 2, + "callbacks": [] + }, + "compile": { + "optimizer": {"class_name": "AdamW", + "config": {"learning_rate": 0.001, "weight_decay": 1e-05}}, + "loss": "mean_squared_error" + }, + "cross_validation": {"class_name": "KFold", + "config": {"n_splits": 5, "random_state": 42, "shuffle": True}}, + "scaler": {"class_name": "StandardLabelScaler", + "config": {"with_std": True, "with_mean": True, "copy": True}} + }, + "data": { + "dataset": { + "class_name": "ESOLDataset", + "module_name": "kgcnn.data.datasets.ESOLDataset", + "config": {}, + "methods": [ + {"set_attributes": {}}, + {"map_list": {"method": "count_nodes_and_edges"}} + ] + }, + "data_unit": "mol/L" + }, + "info": { + "postfix": "", + "postfix_file": "", + "kgcnn_version": "4.0.0" + } + }, + "INorp": { + "model": { + "class_name": "make_model", + "module_name": "kgcnn.literature.INorp", + "config": { + "name": "INorp", + "inputs": [ + {"shape": [None, 41], "name": "node_attributes", "dtype": "float32"}, + {"shape": [None, 11], "name": "edge_attributes", "dtype": "float32"}, + {"shape": [None, 2], "name": "edge_indices", "dtype": "int64"}, + {"shape": [], "name": "graph_size", "dtype": "int64"}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_edges", "dtype": "int64"}, + ], + 'input_tensor_type': "padded", + "input_embedding": None, + "input_node_embedding": {"input_dim": 95, "output_dim": 32}, + "input_edge_embedding": {"input_dim": 15, "output_dim": 32}, + "input_graph_embedding": {"input_dim": 100, "output_dim": 32}, + "set2set_args": {"channels": 32, "T": 3, "pooling_method": "mean", "init_qstar": "mean"}, + "node_mlp_args": {"units": [32, 32], "use_bias": True, "activation": ["relu", "linear"]}, + "edge_mlp_args": {"units": [32, 32], "activation": ["relu", "linear"]}, + "pooling_args": {"pooling_method": "sum"}, + "depth": 3, "use_set2set": False, "verbose": 10, + "gather_args": {}, + "output_embedding": "graph", + "output_mlp": {"use_bias": [True, True, False], "units": [32, 32, 1], + "activation": ["relu", "relu", "linear"]}, + } + }, + "training": { + "fit": { + "batch_size": 32, "epochs": 500, "validation_freq": 2, "verbose": 2, + "callbacks": [ + {"class_name": "kgcnn>LinearLearningRateScheduler", "config": { + "learning_rate_start": 0.5e-03, "learning_rate_stop": 1e-05, "epo_min": 300, "epo": 500, + "verbose": 0 + } + } + ] + }, + "compile": { + "optimizer": {"class_name": "Adam", "config": {"learning_rate": 5e-03}}, + "loss": "mean_absolute_error" + }, + "cross_validation": {"class_name": "KFold", + "config": {"n_splits": 5, "random_state": 42, "shuffle": True}}, + "scaler": {"class_name": "StandardLabelScaler", + "config": {"with_std": True, "with_mean": True, "copy": True}} + }, + "data": { + "dataset": { + "class_name": "ESOLDataset", + "module_name": "kgcnn.data.datasets.ESOLDataset", + "config": {}, + "methods": [ + {"set_attributes": {}}, + {"map_list": {"method": "count_nodes_and_edges"}} + ] + }, + "data_unit": "mol/L" + }, + "info": { + "postfix": "", + "postfix_file": "", + "kgcnn_version": "4.0.0" + } + }, + "MEGAN": { + "model": { + "class_name": "make_model", + "module_name": "kgcnn.literature.MEGAN", + "config": { + 'name': "MEGAN", + 'inputs': [ + {'shape': (None, 41), 'name': "node_attributes", 'dtype': 'float32'}, + {'shape': (None, ), 'name': "edge_number", 'dtype': 'float32'}, + {'shape': (None, 2), 'name': "edge_indices", 'dtype': 'int64'}, + {"shape": [2], "name": "graph_attributes", "dtype": "float32"}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_edges", "dtype": "int64"}, + ], + "input_tensor_type": "padded", + 'units': [60, 50, 40, 30], + 'importance_units': [], + 'final_units': [50, 30, 10, 1], + 'dropout_rate': 0.3, + 'final_dropout_rate': 0.00, + 'importance_channels': 3, + 'return_importances': False, + 'use_edge_features': False, + } + }, + "training": { + "fit": { + "batch_size": 64, + "epochs": 400, + "validation_freq": 10, + "verbose": 2, + "callbacks": [ + { + "class_name": "kgcnn>LinearLearningRateScheduler", "config": { + "learning_rate_start": 1e-03, "learning_rate_stop": 1e-05, "epo_min": 200, "epo": 400, + "verbose": 0 + } + } + ] + }, + "compile": { + "optimizer": {"class_name": "Adam", "config": {"learning_rate": 1e-03}}, + "loss": "mean_squared_error" + }, + "cross_validation": {"class_name": "KFold", + "config": {"n_splits": 5, "random_state": 42, "shuffle": True}}, + "scaler": {"class_name": "StandardLabelScaler", + "config": {"with_std": True, "with_mean": True, "copy": True}} + }, + "data": { + "dataset": { + "class_name": "ESOLDataset", + "module_name": "kgcnn.data.datasets.ESOLDataset", + "config": {}, + "methods": [ + {"set_attributes": {}}, + {"map_list": {"method": "set_range", "max_distance": 3, "max_neighbours": 100}}, + {"map_list": {"method": "count_nodes_and_edges"}} + ] + }, + "data_unit": "mol/L" + }, + "info": { + "postfix": "", + "postfix_file": "", + "kgcnn_version": "4.0.0" + } + }, + "rGIN": { + "model": { + "class_name": "make_model", + "module_name": "kgcnn.literature.rGIN", + "config": { + "name": "rGIN", + "inputs": [ + {"shape": [None, 41], "name": "node_attributes", "dtype": "float32"}, + {"shape": [None, 2], "name": "edge_indices", "dtype": "int64"}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_edges", "dtype": "int64"} + ], + "input_tensor_type": "padded", + "input_embedding": None, + "input_node_embedding": {"input_dim": 96, "output_dim": 95}, + "depth": 5, + "dropout": 0.05, + "gin_mlp": {"units": [64, 64], "use_bias": True, "activation": ["relu", "linear"], + "use_normalization": True, "normalization_technique": "graph"}, + "rgin_args": {"random_range": 100}, + "last_mlp": {"use_bias": True, "units": [64, 32, 1], "activation": ["relu", "relu", "linear"]}, + "output_embedding": "graph", + "output_mlp": {"activation": "linear", "units": 1} + } + }, + "training": { + "fit": {"batch_size": 32, "epochs": 300, "validation_freq": 1, "verbose": 2, "callbacks": [] + }, + "compile": { + "optimizer": {"class_name": "Adam", + "config": {"learning_rate": { + "module": "keras.optimizers.schedules", + "class_name": "ExponentialDecay", + "config": {"initial_learning_rate": 0.001, + "decay_steps": 1600, + "decay_rate": 0.5, "staircase": False}}} + }, + "loss": "mean_absolute_error" + }, + "cross_validation": {"class_name": "KFold", + "config": {"n_splits": 5, "random_state": 42, "shuffle": True}}, + "scaler": {"class_name": "StandardLabelScaler", + "config": {"with_std": True, "with_mean": True, "copy": True}} + }, + "data": { + "dataset": { + "class_name": "ESOLDataset", + "module_name": "kgcnn.data.datasets.ESOLDataset", + "config": {}, + "methods": [ + {"set_attributes": {}}, + {"map_list": {"method": "count_nodes_and_edges"}} + ] + }, + "data_unit": "mol/L" + }, + "info": { + "postfix": "", + "postfix_file": "", + "kgcnn_version": "4.0.0" + } + }, + "MXMNet": { + "model": { + "class_name": "make_model", + "module_name": "kgcnn.literature.MXMNet", + "config": { + "name": "MXMNet", + "inputs": [ + {"shape": (None,), "name": "node_number", "dtype": "int64"}, + {"shape": (None, 3), "name": "node_coordinates", "dtype": "float32"}, + {"shape": (None, 1), "name": "edge_weights", "dtype": "float32"}, + {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, + {"shape": (None, 2), "name": "range_indices", "dtype": "int64"}, + {"shape": [None, 2], "name": "angle_indices_1", "dtype": "int64"}, + {"shape": [None, 2], "name": "angle_indices_2", "dtype": "int64"}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_edges", "dtype": "int64"}, + {"shape": (), "name": "total_ranges", "dtype": "int64"}, + {"shape": (), "name": "total_angles_1", "dtype": "int64"}, + {"shape": (), "name": "total_angles_2", "dtype": "int64"} + ], + "input_tensor_type": "padded", + "input_embedding": None, + "input_node_embedding": {"input_dim": 95, "output_dim": 32}, + "input_edge_embedding": {"input_dim": 5, "output_dim": 32}, + "bessel_basis_local": {"num_radial": 16, "cutoff": 5.0, "envelope_exponent": 5}, + "bessel_basis_global": {"num_radial": 16, "cutoff": 5.0, "envelope_exponent": 5}, + "spherical_basis_local": {"num_spherical": 7, "num_radial": 6, "cutoff": 5.0, "envelope_exponent": 5}, + "mlp_rbf_kwargs": {"units": 32, "activation": "swish"}, + "mlp_sbf_kwargs": {"units": 32, "activation": "swish"}, + "global_mp_kwargs": {"units": 32, "pooling_method": "mean"}, + "local_mp_kwargs": {"units": 32, "output_units": 1, + "output_kernel_initializer": "glorot_uniform"}, + "use_edge_attributes": False, + "depth": 4, + "verbose": 10, + "node_pooling_args": {"pooling_method": "sum"}, + "output_embedding": "graph", "output_to_tensor": True, + "use_output_mlp": False, + "output_mlp": {"use_bias": [True], "units": [1], + "activation": ["linear"]} + } + }, + "training": { + "fit": { + "batch_size": 128, "epochs": 900, "validation_freq": 10, "verbose": 2, + "callbacks": [ + {"class_name": "kgcnn>LinearWarmupExponentialLRScheduler", "config": { + "lr_start": 1e-03, "gamma": 0.9961697, "epo_warmup": 1, "verbose": 1, "steps_per_epoch": 45}} + ] + }, + "compile": { + "optimizer": {"class_name": "Adam", "config": {"learning_rate": 1e-03, "global_clipnorm": 1000}}, + "loss": "mean_absolute_error", + }, + "cross_validation": {"class_name": "KFold", + "config": {"n_splits": 5, "random_state": 42, "shuffle": True}}, + "scaler": {"class_name": "StandardLabelScaler", + "config": {"with_std": True, "with_mean": True, "copy": True}} + }, + "data": { + "dataset": { + "class_name": "QM7Dataset", + "module_name": "kgcnn.data.datasets.QM7Dataset", + "config": {}, + "methods": [ + {"map_list": {"method": "set_edge_weights_uniform"}}, + {"map_list": {"method": "set_range", "max_distance": 5, "max_neighbours": 1000}}, + {"map_list": {"method": "set_angle", "range_indices": "edge_indices", "edge_pairing": "jk", + "angle_indices": "angle_indices_1", + "angle_indices_nodes": "angle_indices_nodes_1", + "angle_attributes": "angle_attributes_1"}}, + {"map_list": {"method": "set_angle", "range_indices": "edge_indices", "edge_pairing": "ik", + "allow_self_edges": True, + "angle_indices": "angle_indices_2", + "angle_indices_nodes": "angle_indices_nodes_2", + "angle_attributes": "angle_attributes_2"}}, + {"map_list": {"method": "count_nodes_and_edges"}}, + {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges", + "count_edges": "range_indices"}}, + {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_angles_1", + "count_edges": "angle_indices_1"}}, + {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_angles_2", + "count_edges": "angle_indices_2"}} + ] + }, + "data_unit": "kcal/mol" + }, + "info": { + "postfix": "", + "postfix_file": "", + "kgcnn_version": "4.0.0" + } + }, }