Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ICML 2023: Initial commit of Simplicial Neural Networks #98

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions test/nn/simplicial/test_snn_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Test the SNN layer."""

import torch

from topomodelx.nn.simplicial.snn_layer import SNNLayer


class TestSNNLayer:
    """Test the SNN layer."""

    def test_forward(self):
        """Test the forward pass of the SNN layer."""
        in_channels = 5
        out_channels = 5
        n_nodes = 10
        K = 5
        # Random 0/1 matrix standing in for the degree-0 simplicial Laplacian.
        lapl_0 = torch.randint(0, 2, (n_nodes, n_nodes)).float()

        x_0 = torch.randn(n_nodes, in_channels)

        snn = SNNLayer(in_channels, out_channels, K)
        output = snn.forward(x_0, lapl_0)

        # One out_channels-dimensional feature per node.
        assert output.shape == (n_nodes, out_channels)

    def test_reset_parameters(self):
        """Test the reset of the parameters."""
        in_channels = 5
        out_channels = 5
        K = 5

        snn = SNNLayer(in_channels, out_channels, K)
        snn.reset_parameters()

        # NOTE(review): the previous version searched for torch.nn.Conv2d
        # submodules, but SNNLayer is built from topomodelx Conv blocks
        # (see snn_layer.py), so the isinstance check presumably never
        # matched and the assertions never ran.  It would also have failed
        # if it had matched: reset_parameters re-initializes weights, it
        # does not zero them.  Instead, verify that every learnable
        # parameter holds finite values after the reset.
        for param in snn.parameters():
            assert torch.all(torch.isfinite(param))
104 changes: 104 additions & 0 deletions topomodelx/nn/simplicial/snn_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Simplicial Neural Network Layer."""
import torch

from topomodelx.base.aggregation import Aggregation
from topomodelx.base.conv import Conv


class SNNLayer(torch.nn.Module):
    """Layer of a Simplicial Neural Network (SNN).

    Implementation of the SNN layer proposed in [SNN20]_.

    The layer applies one convolution per power of the simplicial
    Laplacian (degrees 0 through K - 1) and sums the results.

    References
    ----------
    .. [SNN20] Stefania Ebli, Michael Defferrard and Gard Spreemann.
        Simplicial Neural Networks.
        Topological Data Analysis and Beyond workshop at NeurIPS.
        https://arxiv.org/abs/2010.03633

    Parameters
    ----------
    in_channels : int
        Dimension of features on each simplicial cell.
    out_channels : int
        Dimension of output representation on each simplicial cell.
    K : int
        Number of polynomial terms of the Laplacian: powers 0, ..., K - 1
        are used, each with its own convolution.
    """

    def __init__(self, in_channels, out_channels, K):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K

        # One convolution per polynomial degree of the Laplacian.
        convs = [
            Conv(in_channels=in_channels, out_channels=out_channels, update_func="relu")
            for _ in range(self.K)
        ]

        self.convs = torch.nn.ModuleList(convs)

        # The K per-degree messages are summed, then passed through a ReLU.
        self.aggr = Aggregation(aggr_func="sum", update_func="relu")

    def reset_parameters(self):
        r"""Reset learnable parameters of every per-degree convolution."""
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, laplacian):
        r"""Forward pass.

        The forward pass was initially proposed in [SNN20]_.
        Its equations are adapted from [TNN23]_, graphically illustrated in [PSHM23]_.

        .. math::
            \begin{align*}
            &🟥 \quad m_{y \rightarrow x}^{p, (d \rightarrow d)} = ((H_{d})^p)\_{xy} \cdot h_y^{t,(d)} \cdot \Theta^{t,p}\\
            &🟧 \quad m_{x}^{p, (d \rightarrow d)} = \sum_{y \in (\mathcal{L}\_\uparrow + \mathcal{L}\_\downarrow)(x)} m_{y \rightarrow x}^{p, (d \rightarrow d)}\\
            &🟧 \quad m_x^{(d \rightarrow d)} = \sum_{p=1}^{P_1} (m_{x}^{p,(d \rightarrow d)})^{p}\\
            &🟩 \quad m_x^{(d)} = m_x^{(d \rightarrow d)}\\
            &🟦 \quad h_x^{t+1, (d)} = \sigma (m_{x}^{(d)})
            \end{align*}

        References
        ----------
        .. [SNN20] Stefania Ebli, Michael Defferrard and Gard Spreemann.
            Simplicial Neural Networks.
            Topological Data Analysis and Beyond workshop at NeurIPS.
            https://arxiv.org/abs/2010.03633
        .. [TNN23] Equations of Topological Neural Networks.
            https://github.com/awesome-tnns/awesome-tnns/
        .. [PSHM23] Papillon, Sanborn, Hajij, Miolane.
            Architectures of Topological Deep Learning: A Survey on Topological Neural Networks.
            (2023) https://arxiv.org/abs/2304.10031.

        Parameters
        ----------
        x : torch.Tensor, shape=[n_simplices, in_channels]
            Input features on the simplices of the simplicial complex for the given simplicial degree.
        laplacian : torch.sparse, shape=[n_simplices, n_simplices]
            Simplicial Laplacian matrix for the given simplicial degree.

        Returns
        -------
        _ : torch.Tensor, shape=[n_simplices, out_channels]
            Output features on the simplices of the simplicial complex.
        """
        outputs = []
        # Degree-0 term uses L^0 = I.  Create the identity on the same
        # device and dtype as the Laplacian so the layer also works when
        # its inputs live on the GPU (the default torch.eye is CPU/float32).
        laplacian_power = torch.eye(
            laplacian.shape[0], device=laplacian.device, dtype=laplacian.dtype
        )
        outputs.append(self.convs[0](x, laplacian_power))

        for i in range(1, self.K):
            # Build L^i incrementally from L^(i-1).
            laplacian_power = torch.mm(laplacian_power, laplacian)

            outputs.append(self.convs[i](x, laplacian_power))

        # Sum the K per-degree messages and apply the nonlinearity.
        x = self.aggr(outputs)
        return x
Loading
Loading