Skip to content

Commit

Permalink
Update and test models.
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Dec 10, 2023
1 parent 56c09e3 commit e223723
Show file tree
Hide file tree
Showing 22 changed files with 571 additions and 129 deletions.
3 changes: 3 additions & 0 deletions kgcnn/layers/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,9 @@ def compute_output_shape(self, input_shape):
return tuple([input_shape[0] + 1] + list(input_shape[1:]))
return input_shape

def compute_output_spec(self, input_spec):
return ks.KerasTensor(self.compute_output_shape(input_spec.shape), input_spec.dtype)

def call(self, inputs: list, **kwargs):
r"""Changes graph tensor from disjoint representation.
Expand Down
91 changes: 88 additions & 3 deletions kgcnn/layers/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import numpy as np
from typing import Union
import keras as ks
from keras import ops
from keras import ops, Layer
from keras.layers import Layer, Subtract, Multiply, Add, Subtract
from kgcnn.layers.gather import GatherNodes, GatherState
from kgcnn.layers.gather import GatherNodes, GatherState, GatherNodesOutgoing
from kgcnn.layers.polynom import spherical_bessel_jn_zeros, spherical_bessel_jn_normalization_prefactor, \
tf_spherical_bessel_jn, tf_spherical_harmonics_yl
from kgcnn.ops.axis import get_positive_axis
from kgcnn.ops.core import cross as kgcnn_cross

Expand Down Expand Up @@ -493,7 +495,7 @@ def call(self, inputs, **kwargs):
Tensor: Edge angles between edges that match the indices. Shape is `([K], 1)` .
"""
v1, v2 = self.layer_gather_vectors(inputs)
if self.vector_scale:
if self.vector_scale is not None:
v1, v2 = [
x * ops.cast(self._const_vec_scale[i], dtype=x.dtype) for i, x in enumerate([v1, v2])
]
Expand Down Expand Up @@ -1100,3 +1102,86 @@ def get_config(self):
config = super(RealToFracCoordinates, self).get_config()
config.update({"is_inverse_lattice_matrix": self.is_inverse_lattice_matrix})
return config


class SphericalBasisLayer(Layer):
r"""Expand a distance into a Bessel Basis with :math:`l=m=0`, according to
`Klicpera et al. 2020 <https://arxiv.org/abs/2011.14115>`__ .
Args:
num_spherical (int): Number of spherical basis functions
num_radial (int): Number of radial basis functions
cutoff (float): Cutoff distance c
envelope_exponent (int): Degree of the envelope to smoothen at cutoff. Default is 5.
"""

def __init__(self, num_spherical,
num_radial,
cutoff,
envelope_exponent=5,
**kwargs):
super(SphericalBasisLayer, self).__init__(**kwargs)

assert num_radial <= 64
self.num_radial = int(num_radial)
self.num_spherical = num_spherical
self.cutoff = cutoff
self.inv_cutoff = ops.convert_to_tensor(1.0 / cutoff, dtype=self.dtype)
self.envelope_exponent = envelope_exponent

# retrieve formulas
self.bessel_n_zeros = spherical_bessel_jn_zeros(num_spherical, num_radial)
self.bessel_norm = spherical_bessel_jn_normalization_prefactor(num_spherical, num_radial)

self.layer_gather_out = GatherNodesOutgoing()

def envelope(self, inputs):
p = self.envelope_exponent + 1
a = -(p + 1) * (p + 2) / 2
b = p * (p + 2)
c = -p * (p + 1) / 2
env_val = 1 / inputs + a * inputs ** (p - 1) + b * inputs ** p + c * inputs ** (p + 1)
return ops.where(inputs < 1, env_val, ops.zeros_like(inputs))

def call(self, inputs, **kwargs):
"""Forward pass.
Args:
inputs: [distance, angles, angle_index]
- distance (Tensor): Edge distance of shape ([M], 1)
- angles (Tensor): Angle list of shape ([K], 1)
- angle_index (Tensor): Angle indices referring to edges of shape (2, [K])
Returns:
Tensor: Expanded angle/distance basis. Shape is ([K], #Radial * #Spherical)
"""
edge, angles, angle_index = inputs

d = edge
d_scaled = d[:, 0] * self.inv_cutoff
rbf = []
for n in range(self.num_spherical):
for k in range(self.num_radial):
rbf += [self.bessel_norm[n, k] * tf_spherical_bessel_jn(d_scaled * self.bessel_n_zeros[n][k], n)]
rbf = ops.stack(rbf, axis=1)

d_cutoff = self.envelope(d_scaled)
rbf_env = d_cutoff[:, None] * rbf
rbf_env = self.layer_gather_out([rbf_env, angle_index], **kwargs)
# rbf_env = tf.gather(rbf_env, id_expand_kj[:, 1])

cbf = [tf_spherical_harmonics_yl(angles[:, 0], n) for n in range(self.num_spherical)]
cbf = ops.stack(cbf, axis=1)
cbf = ops.repeat(cbf, self.num_radial, axis=1)
out = rbf_env * cbf

return out

def get_config(self):
"""Update config."""
config = super(SphericalBasisLayer, self).get_config()
config.update({"num_radial": self.num_radial, "cutoff": self.cutoff,
"envelope_exponent": self.envelope_exponent, "num_spherical": self.num_spherical})
return config
18 changes: 18 additions & 0 deletions kgcnn/layers/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,24 @@ def get_config(self):
return config


class SqueezeDims(ks.layers.Layer):

def __init__(self, axis, **kwargs):
super(SqueezeDims, self).__init__(**kwargs)
self.axis = axis

def build(self, input_shape):
self.built = True

def call(self, inputs):
return ops.squeeze(inputs, axis=self.axis)

def get_config(self):
config = super(SqueezeDims, self).get_config()
config.update({"axis": self.axis})
return config


def Input(
shape=None,
batch_size=None,
Expand Down
85 changes: 0 additions & 85 deletions kgcnn/literature/DimeNetPP/_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from kgcnn.layers.gather import GatherNodesOutgoing
from kgcnn.layers.mlp import GraphMLP
from kgcnn.layers.update import ResidualLayer
from kgcnn.layers.polynom import spherical_bessel_jn_zeros, spherical_bessel_jn_normalization_prefactor, \
tf_spherical_bessel_jn, tf_spherical_harmonics_yl
from kgcnn.initializers.initializers import GlorotOrthogonal, HeOrthogonal
from kgcnn.ops.activ import swish

Expand Down Expand Up @@ -311,86 +309,3 @@ def get_config(self):
"embeddings_constraint": ks.constraints.serialize(self.embeddings_constraint)
})
return config


class SphericalBasisLayer(Layer):
r"""Expand a distance into a Bessel Basis with :math:`l=m=0`, according to
`Klicpera et al. 2020 <https://arxiv.org/abs/2011.14115>`__ .
Args:
num_spherical (int): Number of spherical basis functions
num_radial (int): Number of radial basis functions
cutoff (float): Cutoff distance c
envelope_exponent (int): Degree of the envelope to smoothen at cutoff. Default is 5.
"""

def __init__(self, num_spherical,
num_radial,
cutoff,
envelope_exponent=5,
**kwargs):
super(SphericalBasisLayer, self).__init__(**kwargs)

assert num_radial <= 64
self.num_radial = int(num_radial)
self.num_spherical = num_spherical
self.cutoff = cutoff
self.inv_cutoff = ops.convert_to_tensor(1.0 / cutoff, dtype=self.dtype)
self.envelope_exponent = envelope_exponent

# retrieve formulas
self.bessel_n_zeros = spherical_bessel_jn_zeros(num_spherical, num_radial)
self.bessel_norm = spherical_bessel_jn_normalization_prefactor(num_spherical, num_radial)

self.layer_gather_out = GatherNodesOutgoing()

def envelope(self, inputs):
p = self.envelope_exponent + 1
a = -(p + 1) * (p + 2) / 2
b = p * (p + 2)
c = -p * (p + 1) / 2
env_val = 1 / inputs + a * inputs ** (p - 1) + b * inputs ** p + c * inputs ** (p + 1)
return ops.where(inputs < 1, env_val, ops.zeros_like(inputs))

def call(self, inputs, **kwargs):
"""Forward pass.
Args:
inputs: [distance, angles, angle_index]
- distance (Tensor): Edge distance of shape ([M], 1)
- angles (Tensor): Angle list of shape ([K], 1)
- angle_index (Tensor): Angle indices referring to edges of shape (2, [K])
Returns:
Tensor: Expanded angle/distance basis. Shape is ([K], #Radial * #Spherical)
"""
edge, angles, angle_index = inputs

d = edge
d_scaled = d[:, 0] * self.inv_cutoff
rbf = []
for n in range(self.num_spherical):
for k in range(self.num_radial):
rbf += [self.bessel_norm[n, k] * tf_spherical_bessel_jn(d_scaled * self.bessel_n_zeros[n][k], n)]
rbf = ops.stack(rbf, axis=1)

d_cutoff = self.envelope(d_scaled)
rbf_env = d_cutoff[:, None] * rbf
rbf_env = self.layer_gather_out([rbf_env, angle_index], **kwargs)
# rbf_env = tf.gather(rbf_env, id_expand_kj[:, 1])

cbf = [tf_spherical_harmonics_yl(angles[:, 0], n) for n in range(self.num_spherical)]
cbf = ops.stack(cbf, axis=1)
cbf = ops.repeat(cbf, self.num_radial, axis=1)
out = rbf_env * cbf

return out

def get_config(self):
"""Update config."""
config = super(SphericalBasisLayer, self).get_config()
config.update({"num_radial": self.num_radial, "cutoff": self.cutoff,
"envelope_exponent": self.envelope_exponent, "num_spherical": self.num_spherical})
return config
5 changes: 3 additions & 2 deletions kgcnn/literature/DimeNetPP/_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from keras.layers import Add, Subtract, Concatenate, Dense
from kgcnn.layers.geom import NodePosition, NodeDistanceEuclidean, BesselBasisLayer, EdgeAngle, ShiftPeriodicLattice
from kgcnn.layers.geom import NodePosition, NodeDistanceEuclidean, BesselBasisLayer, EdgeAngle, ShiftPeriodicLattice, \
SphericalBasisLayer
from kgcnn.layers.gather import GatherNodes
from kgcnn.layers.pooling import PoolingNodes
from kgcnn.layers.mlp import MLP
from ._layers import DimNetInteractionPPBlock, EmbeddingDimeBlock, SphericalBasisLayer, DimNetOutputBlock
from ._layers import DimNetInteractionPPBlock, EmbeddingDimeBlock, DimNetOutputBlock


def model_disjoint(
Expand Down
7 changes: 7 additions & 0 deletions kgcnn/literature/INorp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from ._make import make_model, model_default


__all__ = [
"make_model",
"model_default"
]
3 changes: 2 additions & 1 deletion kgcnn/literature/INorp/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
{"shape": (), "name": "total_nodes", "dtype": "int64"},
{"shape": (), "name": "total_edges", "dtype": "int64"}
],
'input_tensor_type': "padded",
'input_embedding': None,
"cast_disjoint_kwargs": {},
"input_node_embedding": {"input_dim": 95, "output_dim": 64},
Expand Down Expand Up @@ -139,7 +140,7 @@ def make_model(inputs: list = None,

# Wrapping disjoint model.
out = model_disjoint(
[n, ed, disjoint_indices, gs, batch_id_node, count_nodes],
[n, ed, disjoint_indices, gs, batch_id_node, batch_id_edge, count_nodes, count_edges],
use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False,
use_edge_embedding=("int" in inputs[1]['dtype']) if input_edge_embedding is not None else False,
use_graph_embedding=("int" in inputs[3]['dtype']) if input_graph_embedding is not None else False,
Expand Down
2 changes: 1 addition & 1 deletion kgcnn/literature/INorp/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def model_disjoint(
uenv = Embedding(**input_graph_embedding)(uenv)

# Model
ev = GatherState(**gather_args)([uenv, n])
ev = GatherState(**gather_args)([uenv, batch_id_node])
# n-Layer Step
for i in range(0, depth):
# upd = GatherNodes()([n,edi])
Expand Down
10 changes: 10 additions & 0 deletions kgcnn/literature/MEGAN/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from ._model import MEGAN, shifted_sigmoid, ExplanationSparsityRegularization
from ._make import make_model


__all__ = [
"make_model",
"MEGAN",
"ExplanationSparsityRegularization",
"shifted_sigmoid"
]
3 changes: 3 additions & 0 deletions kgcnn/literature/MEGAN/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@


model_default = {
"name": "MEGAN",
"inputs": [
{'shape': (None, 128), 'name': "node_attributes", 'dtype': 'float32'},
{'shape': (None, 64), 'name': "edge_attributes", 'dtype': 'float32'},
Expand All @@ -32,6 +33,7 @@
{"shape": (), "name": "total_nodes", "dtype": "int64"},
{"shape": (), "name": "total_edges", "dtype": "int64"}
],
"input_tensor_type": "padded",
'input_embedding': None,
"cast_disjoint_kwargs": {},
"units": [128],
Expand Down Expand Up @@ -72,6 +74,7 @@ def make_model(inputs: list = None,
use_bias: bool = None,
dropout_rate: float = None,
use_edge_features: bool = None,
input_embedding: dict = None, # deprecated
input_node_embedding: dict = None,
# node/edge importance related arguments
importance_units: list = None,
Expand Down
12 changes: 6 additions & 6 deletions kgcnn/literature/MEGAN/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(self,
of the model. If this is True, the output of the model will be a 3-tuple:
(output, node importances, edge importances), otherwise it is just the output itself
"""
super(MEGAN, self).__init__(self, **kwargs)
super().__init__()
self.units = units
self.activation = activation
self.use_bias = use_bias
Expand Down Expand Up @@ -234,7 +234,7 @@ def call(self,

node_input, edge_input, edge_index_input, out_true, batch_id_node, count_nodes = inputs

if self.input_embedding:
if self.input_node_embedding:
node_input = self.embedding_nodes(node_input, training=training)
# First of all we apply all the graph convolutional / attention layers. Each of those layers outputs
# the attention logits alpha additional to the node embeddings. We collect all the attention logits
Expand Down Expand Up @@ -278,8 +278,8 @@ def call(self,
# the weights! We concatenate all the individual results in the end.
outs = []
for k in range(self.importance_channels):
node_importance_slice = ops.expand_dims(node_importances[:, :, k], axis=-1)
out = self.lay_pool_out(x * node_importance_slice)
node_importance_slice = ops.expand_dims(node_importances[:, k], axis=-1)
out = self.lay_pool_out([count_nodes, x * node_importance_slice, batch_id_node])

outs.append(out)

Expand Down Expand Up @@ -308,8 +308,8 @@ def call(self,
# concatenate into an output vector with K dimensions.
outs = []
for k in range(self.importance_channels):
node_importances_slice = ops.expand_dims(ni_pred[:, :, k], axis=-1)
out = self.lay_pool_out(node_importances_slice)
node_importances_slice = ops.expand_dims(ni_pred[:, k], axis=-1)
out = self.lay_pool_out([count_nodes, node_importances_slice, batch_id_node])

outs.append(out)

Expand Down
7 changes: 7 additions & 0 deletions kgcnn/literature/MXMNet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from ._make import make_model, model_default


__all__ = [
"make_model",
"model_default"
]
Loading

0 comments on commit e223723

Please sign in to comment.