update for keras 3.0
PatReis committed Nov 15, 2023
1 parent 15af06e commit f8de6d5
Showing 9 changed files with 684 additions and 11 deletions.
115 changes: 115 additions & 0 deletions kgcnn/layers/attention.py
@@ -330,3 +330,118 @@ def get_config(self):
'concat_heads': self.concat_heads
})
return config


class AttentiveHeadFP(Layer):
r"""Computes the attention head for `Attentive FP <https://doi.org/10.1021/acs.jmedchem.9b00959>`__ model.
The attention coefficients are computed by :math:`a_{ij} = \sigma_1( W_1 [h_i || h_j] )`.
The initial representation :math:`h_i` and :math:`h_j` must be calculated beforehand.
The attention is obtained by :math:`\alpha_{ij} = \text{softmax}_j (a_{ij})`.
And finally pooled for context :math:`C_i = \sigma_2(\sum_j \alpha_{ij} W_2 h_j)`.
An edge is defined by index tuple :math:`(i, j)` with the direction of the connection from :math:`j` to :math:`i`.
"""

def __init__(self,
units,
use_edge_features=False,
activation='kgcnn>leaky_relu',
activation_context="elu",
use_bias=True,
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
**kwargs):
"""Initialize layer.
Args:
units (int): Units for the linear trafo of node features before attention.
use_edge_features (bool): Append edge features to attention computation. Default is False.
activation (str): Activation. Default is {"class_name": "kgcnn>leaky_relu", "config": {"alpha": 0.2}}.
activation_context (str): Activation function for context. Default is "elu".
use_bias (bool): Use bias. Default is True.
kernel_regularizer: Kernel regularization. Default is None.
bias_regularizer: Bias regularization. Default is None.
activity_regularizer: Activity regularization. Default is None.
kernel_constraint: Kernel constraints. Default is None.
bias_constraint: Bias constraints. Default is None.
kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'.
bias_initializer: Initializer for bias. Default is 'zeros'.
"""
super(AttentiveHeadFP, self).__init__(**kwargs)
self.use_edge_features = use_edge_features
self.units = int(units)
self.use_bias = use_bias
kernel_args = {"kernel_regularizer": kernel_regularizer,
"activity_regularizer": activity_regularizer, "bias_regularizer": bias_regularizer,
"kernel_constraint": kernel_constraint, "bias_constraint": bias_constraint,
"kernel_initializer": kernel_initializer, "bias_initializer": bias_initializer}

self.lay_linear_trafo = Dense(units, activation="linear", use_bias=use_bias, **kernel_args)
self.lay_alpha_activation = Dense(units, activation=activation, use_bias=use_bias, **kernel_args)
self.lay_alpha = Dense(1, activation="linear", use_bias=False, **kernel_args)
self.lay_gather_in = GatherNodesIngoing()
self.lay_gather_out = GatherNodesOutgoing()
self.lay_concat = Concatenate(axis=-1)
self.lay_pool_attention = AggregateLocalEdgesAttention()
self.lay_final_activ = Activation(activation=activation_context)
if use_edge_features:
self.lay_fc1 = Dense(units, activation=activation, use_bias=use_bias, **kernel_args)
self.lay_fc2 = Dense(units, activation=activation, use_bias=use_bias, **kernel_args)
self.lay_concat_edge = Concatenate(axis=-1)

def build(self, input_shape):
"""Build layer."""
super(AttentiveHeadFP, self).build(input_shape)

def call(self, inputs, **kwargs):
r"""Forward pass.
Args:
inputs (list): [nodes, edges, edge_indices]
- nodes (Tensor): Node embeddings of shape ([N], F)
- edges (Tensor): Edge or message embeddings of shape ([M], F)
- edge_indices (Tensor): Edge indices referring to nodes of shape ([M], 2)
Returns:
Tensor: Hidden tensor of pooled edge attentions for each node.
"""
node, edge, edge_index = inputs

if self.use_edge_features:
n_in = self.lay_gather_in([node, edge_index], **kwargs)
n_out = self.lay_gather_out([node, edge_index], **kwargs)
n_in = self.lay_fc1(n_in, **kwargs)
n_out = self.lay_concat_edge([n_out, edge], **kwargs)
n_out = self.lay_fc2(n_out, **kwargs)
else:
n_in = self.lay_gather_in([node, edge_index], **kwargs)
n_out = self.lay_gather_out([node, edge_index], **kwargs)

wn_out = self.lay_linear_trafo(n_out, **kwargs)
e_ij = self.lay_concat([n_in, n_out], **kwargs)
e_ij = self.lay_alpha_activation(e_ij, **kwargs)  # Possibly follows the original GAT definition.
# a_ij = e_ij
a_ij = self.lay_alpha(e_ij, **kwargs)  # Should be of shape ([M], 1); not fully clear in the original paper.
n_i = self.lay_pool_attention([node, wn_out, a_ij, edge_index], **kwargs)
out = self.lay_final_activ(n_i, **kwargs)
return out

def get_config(self):
"""Update layer config."""
config = super(AttentiveHeadFP, self).get_config()
config.update({"use_edge_features": self.use_edge_features, "use_bias": self.use_bias,
"units": self.units})
conf_sub = self.lay_alpha_activation.get_config()
for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint",
"bias_constraint", "kernel_initializer", "bias_initializer", "activation"]:
if x in conf_sub.keys():
config.update({x: conf_sub[x]})
conf_context = self.lay_final_activ.get_config()
config.update({"activation_context": conf_context["activation"]})
return config
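For reference, a minimal usage sketch (not part of this commit) of how the new AttentiveHeadFP layer could be called on a small disjoint graph, assuming the flat node/edge tensors and the ([M], 2) index layout documented in the docstring above; all sizes and values below are made up for illustration.

# Illustrative sketch only, not part of the commit; assumes the disjoint
# tensor layout described in the AttentiveHeadFP docstring above.
import numpy as np
from kgcnn.layers.attention import AttentiveHeadFP

nodes = np.random.rand(5, 16).astype("float32")        # ([N]=5, F=16) node embeddings
edges = np.random.rand(8, 4).astype("float32")         # ([M]=8, F=4) edge features
edge_indices = np.random.randint(0, 5, size=(8, 2))    # ([M], 2) index tuples (i, j)

att_head = AttentiveHeadFP(units=32, use_edge_features=True)
context = att_head([nodes, edges, edge_indices])       # pooled context per node, shape ([N], 32)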
192 changes: 191 additions & 1 deletion kgcnn/layers/pooling.py
@@ -1,6 +1,8 @@
import keras_core as ks
from keras_core.layers import Layer
from keras_core.layers import Layer, Dense, Concatenate, GRUCell, Activation
from kgcnn.layers.gather import GatherState
from keras_core import ops
from kgcnn.ops.scatter import scatter_reduce_softmax
from kgcnn.layers.aggr import Aggregate


@@ -21,3 +23,191 @@ def compute_output_shape(self, input_shape):
def call(self, inputs, **kwargs):
reference, x, idx = inputs
return self._to_aggregate([x, idx, reference])


class PoolingEmbeddingAttention(Layer):
r"""Polling all embeddings of edges or nodes per batch to obtain a graph level embedding in form of a
:obj:`Tensor` .
Uses attention for pooling. i.e. :math:`s = \sum_j \alpha_{i} n_i` .
The attention is computed via: :math:`\alpha_i = \text{softmax}_i(a_i)` from the attention
coefficients :math:`a_i` .
The attention coefficients must be computed beforehand by node or edge features or by :math:`\sigma( W [s || n_i])`
and are passed to this layer as input. Thereby this layer has no weights and only does pooling.
In summary, :math:`s = \sum_i \text{softmax}_j(a_i) n_i` is computed by the layer.
"""

def __init__(self,
softmax_method="scatter_softmax",
pooling_method="scatter_sum",
normalize_softmax: bool = False,
**kwargs):
"""Initialize layer.
Args:
softmax_method (str): Method used for the softmax over segments. Default is "scatter_softmax".
pooling_method (str): Pooling method for the final aggregation. Default is "scatter_sum".
normalize_softmax (bool): Whether to use a normalized softmax. Default is False.
"""
super(PoolingEmbeddingAttention, self).__init__(**kwargs)
self.normalize_softmax = normalize_softmax
self.pooling_method = pooling_method
self.softmax_method = softmax_method
self.to_aggregate = Aggregate(pooling_method=pooling_method)

def build(self, input_shape):
"""Build layer."""
assert len(input_shape) == 4
ref_shape, attr_shape, attention_shape, index_shape = [list(x) for x in input_shape]
self.to_aggregate.build([tuple(x) for x in [attr_shape, index_shape, ref_shape]])
self.built = True

def call(self, inputs, **kwargs):
r"""Forward pass.
Args:
inputs: [reference, attr, attention, batch_index]
- reference (Tensor): Reference for aggregation of shape `(batch, ...)` .
- attr (Tensor): Node or edge embeddings of shape `([N], F)` .
- attention (Tensor): Attention coefficients of shape `([N], 1)` .
- batch_index (Tensor): Batch assignment of shape `([N], )` .
Returns:
Tensor: Embedding tensor of pooled nodes or edges of shape `(batch, F)` .
"""
reference, attr, attention, batch_index = inputs
shape_attention = ops.shape(reference)[:1] + ops.shape(attention)[1:]
a = scatter_reduce_softmax(batch_index, attention, shape=shape_attention, normalize=self.normalize_softmax)
x = attr * ops.broadcast_to(a, ops.shape(attr))
return self.to_aggregate([x, batch_index, reference])

def get_config(self):
"""Update layer config."""
config = super(PoolingEmbeddingAttention, self).get_config()
config.update({
"normalize_softmax": self.normalize_softmax, "pooling_method": self.pooling_method,
"softmax_method": self.softmax_method
})
return config


PoolingNodesAttention = PoolingEmbeddingAttention
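To make the formula in the PoolingEmbeddingAttention docstring concrete, here is a small plain-numpy illustration (not part of the commit) of the attention-weighted segment sum the layer performs; the explicit loop is a naive stand-in for the scatter_reduce_softmax and Aggregate calls used internally.

# Naive numpy illustration of s = sum_i softmax_i(a_i) * n_i per graph;
# not the layer's actual scatter-based implementation.
import numpy as np

attr = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])  # ([N]=3, F=2) node embeddings
a = np.array([[0.5], [1.5], [0.2]])                     # ([N], 1) attention coefficients a_i
batch_index = np.array([0, 0, 1])                       # graph assignment of each node

pooled = np.zeros((2, attr.shape[1]))                   # (batch=2, F) graph embeddings
for b in range(2):
    mask = batch_index == b
    alpha = np.exp(a[mask]) / np.exp(a[mask]).sum()     # softmax over the nodes of graph b
    pooled[b] = (alpha * attr[mask]).sum(axis=0)        # s_b = sum_i alpha_i n_i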


class PoolingNodesAttentive(Layer):
r"""Computes the attentive pooling for node embeddings for
`Attentive FP <https://doi.org/10.1021/acs.jmedchem.9b00959>`__ model.
"""

def __init__(self,
units,
depth=3,
pooling_method="sum",
activation='kgcnn>leaky_relu',
activation_context="elu",
use_bias=True,
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
recurrent_activation='sigmoid',
recurrent_initializer='orthogonal',
recurrent_regularizer=None,
recurrent_constraint=None,
dropout=0.0,
recurrent_dropout=0.0,
reset_after=True,
**kwargs):
"""Initialize layer.
Args:
units (int): Units for the linear trafo of node features before attention.
pooling_method (str): Initial pooling before iteration. Default is "sum".
depth (int): Number of iterations for graph embedding. Default is 3.
activation (str): Activation. Default is {"class_name": "kgcnn>leaky_relu", "config": {"alpha": 0.2}}.
activation_context (str): Activation function for context. Default is "elu".
use_bias (bool): Use bias. Default is True.
kernel_regularizer: Kernel regularization. Default is None.
bias_regularizer: Bias regularization. Default is None.
activity_regularizer: Activity regularization. Default is None.
kernel_constraint: Kernel constraints. Default is None.
bias_constraint: Bias constraints. Default is None.
kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'.
bias_initializer: Initializer for bias. Default is 'zeros'.
"""
super(PoolingNodesAttentive, self).__init__(**kwargs)
self.pooling_method = pooling_method
self.depth = depth
self.units = int(units)
kernel_args = {"use_bias": use_bias, "kernel_regularizer": kernel_regularizer,
"activity_regularizer": activity_regularizer, "bias_regularizer": bias_regularizer,
"kernel_constraint": kernel_constraint, "bias_constraint": bias_constraint,
"kernel_initializer": kernel_initializer, "bias_initializer": bias_initializer}
gru_args = {"recurrent_activation": recurrent_activation,
"use_bias": use_bias, "kernel_initializer": kernel_initializer,
"recurrent_initializer": recurrent_initializer, "bias_initializer": bias_initializer,
"kernel_regularizer": kernel_regularizer, "recurrent_regularizer": recurrent_regularizer,
"bias_regularizer": bias_regularizer, "kernel_constraint": kernel_constraint,
"recurrent_constraint": recurrent_constraint, "bias_constraint": bias_constraint,
"dropout": dropout, "recurrent_dropout": recurrent_dropout, "reset_after": reset_after}

self.lay_linear_trafo = Dense(units, activation="linear", **kernel_args)
self.lay_alpha = Dense(1, activation=activation, **kernel_args)
self.lay_gather_s = GatherState()
self.lay_concat = Concatenate(axis=-1)
self.lay_pool_start = PoolingNodes(pooling_method=self.pooling_method)
self.lay_pool_attention = PoolingNodesAttention()
self.lay_final_activ = Activation(activation=activation_context)
self.lay_gru = GRUCell(units=units, activation="tanh", **gru_args)

def build(self, input_shape):
"""Build layer."""
super(PoolingNodesAttentive, self).build(input_shape)

def call(self, inputs, **kwargs):
"""Forward pass.
Args:
inputs: [reference, nodes, batch_index]
- reference (Tensor): Reference for aggregation of shape `(batch, ...)` .
- nodes (Tensor): Node embeddings of shape `([N], F)` .
- batch_index (Tensor): Batch assignment of shape `([N], )` .
Returns:
Tensor: Hidden tensor of pooled node attentions of shape (batch, F).
"""
ref, node, batch_index = inputs

h = self.lay_pool_start([ref, node, batch_index], **kwargs)
wn = self.lay_linear_trafo(node, **kwargs)
for _ in range(self.depth):
hv = self.lay_gather_s([h, batch_index], **kwargs)
ev = self.lay_concat([hv, node], **kwargs)
av = self.lay_alpha(ev, **kwargs)
cont = self.lay_pool_attention([ref, wn, av, batch_index], **kwargs)
cont = self.lay_final_activ(cont, **kwargs)
h, _ = self.lay_gru(cont, h, **kwargs)

out = h
return out

def get_config(self):
"""Update layer config."""
config = super(PoolingNodesAttentive, self).get_config()
config.update({"units": self.units, "depth": self.depth, "pooling_method": self.pooling_method})
conf_sub = self.lay_alpha.get_config()
for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint",
"bias_constraint", "kernel_initializer", "bias_initializer", "activation", "use_bias"]:
if x in conf_sub.keys():
config.update({x: conf_sub[x]})
conf_context = self.lay_final_activ.get_config()
config.update({"activation_context": conf_context["activation"]})
conf_gru = self.lay_gru.get_config()
for x in ["recurrent_activation", "recurrent_initializer", "recurrent_regularizer", "recurrent_constraint",
"dropout", "recurrent_dropout", "reset_after"]:
if x in conf_gru.keys():
config.update({x: conf_gru[x]})
return config
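Again for reference only (not part of the commit), a minimal sketch of calling the new PoolingNodesAttentive readout on a two-graph disjoint batch, assuming the [reference, nodes, batch_index] input layout documented above. The node feature dimension matches units here so the internal GRU state lines up, and the zero-valued reference tensor is purely illustrative.

# Illustrative sketch only, not part of the commit; assumes the disjoint
# input layout from the PoolingNodesAttentive docstring above and node
# features of size `units`.
import numpy as np
from kgcnn.layers.pooling import PoolingNodesAttentive

nodes = np.random.rand(6, 8).astype("float32")      # ([N]=6, F=8) node embeddings of two graphs
batch_index = np.array([0, 0, 0, 1, 1, 1])          # graph assignment per node
reference = np.zeros((2, 1), dtype="float32")       # (batch=2, ...) reference for aggregation

readout = PoolingNodesAttentive(units=8, depth=2)
graph_state = readout([reference, nodes, batch_index])   # (batch=2, 8) graph embeddings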