Skip to content

Commit

Permalink
Update docs for PoolinNodes.
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Jan 4, 2024
1 parent 762fd88 commit d67703c
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion kgcnn/layers/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,49 @@


class PoolingNodes(Layer):
r"""Main layer to pool node or edge attributes. Uses :obj:`Aggregate` layer."""

def __init__(self, pooling_method="scatter_sum", **kwargs):
"""Initialize layer.
Args:
pooling_method (str): Pooling method to use i.e. segment_function. Default is 'scatter_sum'.
"""
super(PoolingNodes, self).__init__(**kwargs)
self.pooling_method = pooling_method
self._to_aggregate = Aggregate(pooling_method=pooling_method)

def build(self, input_shape):
"""Build Layer."""
self._to_aggregate.build([input_shape[1], input_shape[2], input_shape[0]])
self.built = True

def compute_output_shape(self, input_shape):
"""Compute output shape."""
return self._to_aggregate.compute_output_shape([input_shape[1], input_shape[2], input_shape[0]])

def call(self, inputs, **kwargs):
r"""Forward pass.
Args:
inputs: [reference, attr, weights, batch_index]
- reference (Tensor): Reference for aggregation of shape `(batch, ...)` .
- attr (Tensor): Node or edge embeddings of shape `([N], F)` .
- batch_index (Tensor): Batch assignment of shape `([N], )` .
Returns:
Tensor: Embedding tensor of pooled node of shape `(batch, F)` .
"""
reference, x, idx = inputs
return self._to_aggregate([x, idx, reference])

def get_config(self):
"""Update layer config."""
config = super(PoolingNodes, self).get_config()
config.update({"pooling_method": self.pooling_method})
return config


class PoolingWeightedNodes(Layer):
r"""Weighted polling all embeddings of edges or nodes per batch to obtain a graph level embedding.
Expand Down Expand Up @@ -52,7 +78,7 @@ def build(self, input_shape):
self.built = True

def call(self, inputs, **kwargs):
"""Forward pass.
r"""Forward pass.
Args:
inputs: [reference, attr, weights, batch_index]
Expand Down

0 comments on commit d67703c

Please sign in to comment.