Source code for topomodelx.nn.simplicial.scn2_layer

"""Simplex Convolutional Network (SCN) Layer [Yang et al. LoG 2022]."""
import torch

from topomodelx.base.conv import Conv


class SCN2Layer(torch.nn.Module):
    """Layer of a Simplex Convolutional Network (SCN).

    Implementation of the SCN layer proposed in [1]_ for a simplicial complex
    of rank 2, that is for 0-cells (nodes), 1-cells (edges) and 2-cells
    (faces) only.

    This layer corresponds to the rightmost tensor diagram labeled Yang22c in
    Figure 11 of [3]_.

    Parameters
    ----------
    in_channels_0 : int
        Dimension of input features on nodes (0-cells).
    in_channels_1 : int
        Dimension of input features on edges (1-cells).
    in_channels_2 : int
        Dimension of input features on faces (2-cells).

    See Also
    --------
    topomodelx.nn.simplicial.sccn_layer.SCCNLayer : SCCN layer
        Simplicial Complex Convolutional Network (SCCN) layer proposed in [1]_.
        The difference between SCCN and SCN is that:

        - SCN passes messages between cells of the same rank,
        - SCCN passes messages between cells of the same ranks, one rank
          above and one rank below.

    Notes
    -----
    This architecture is proposed for simplicial complex classification.

    Each rank owns an independent convolution whose output dimension equals
    its input dimension, so the layer preserves the feature size of every
    rank and can be stacked.

    References
    ----------
    .. [1] Yang, Sala and Bogdan.
        Efficient representation learning for higher-order data with
        simplicial complexes (2022).
        https://proceedings.mlr.press/v198/yang22a.html.
    .. [2] Papillon, Sanborn, Hajij, Miolane.
        Equations of topological neural networks (2023).
        https://github.com/awesome-tnns/awesome-tnns/
    .. [3] Papillon, Sanborn, Hajij, Miolane.
        Architectures of topological deep learning: a survey on topological
        neural networks (2023).
        https://arxiv.org/abs/2304.10031.
    """

    def __init__(
        self, in_channels_0: int, in_channels_1: int, in_channels_2: int
    ) -> None:
        super().__init__()
        # One rank-preserving convolution per cell rank; in and out channel
        # counts are deliberately identical.
        self.conv_0_to_0 = Conv(in_channels=in_channels_0, out_channels=in_channels_0)
        self.conv_1_to_1 = Conv(in_channels=in_channels_1, out_channels=in_channels_1)
        self.conv_2_to_2 = Conv(in_channels=in_channels_2, out_channels=in_channels_2)

    def reset_parameters(self) -> None:
        r"""Reset learnable parameters.

        Delegates to ``reset_parameters`` of each of the three per-rank
        convolutions.
        """
        self.conv_0_to_0.reset_parameters()
        self.conv_1_to_1.reset_parameters()
        self.conv_2_to_2.reset_parameters()

    def forward(
        self,
        x_0: torch.Tensor,
        x_1: torch.Tensor,
        x_2: torch.Tensor,
        laplacian_0: torch.Tensor,
        laplacian_1: torch.Tensor,
        laplacian_2: torch.Tensor,
    ):
        r"""Forward pass (see [2]_ and [3]_).

        The same update is applied independently to every rank
        :math:`r \in \{0, 1, 2\}`, with :math:`\sigma` being ReLU:

        .. math::
            \begin{align*}
            &🟥 \quad m^{(r \rightarrow r)}_{y \rightarrow x}
                = (2I + H_r)_{xy} \cdot h_{y}^{t,(r)} \cdot \Theta^{t,(r)}\\
            &🟧 \quad m_x^{(r \rightarrow r)}
                = \sum_{y \in (\mathcal{L}_\downarrow+\mathcal{L}_\uparrow)(x)}
                  m_{y \rightarrow x}^{(r \rightarrow r)}\\
            &🟩 \quad m_x^{(r)} = m^{(r \rightarrow r)}_x\\
            &🟦 \quad h_x^{t+1,(r)} = \sigma(m_{x}^{(r)})
            \end{align*}

        Parameters
        ----------
        x_0 : torch.Tensor, shape = (n_nodes, node_features)
            Input features on the nodes of the simplicial complex.
        x_1 : torch.Tensor, shape = (n_edges, edge_features)
            Input features on the edges of the simplicial complex.
        x_2 : torch.Tensor, shape = (n_faces, face_features)
            Input features on the faces of the simplicial complex.
        laplacian_0 : torch.sparse, shape = (n_nodes, n_nodes)
            Normalized Hodge Laplacian matrix = L_upper + L_lower.
        laplacian_1 : torch.sparse, shape = (n_edges, n_edges)
            Normalized Hodge Laplacian matrix.
        laplacian_2 : torch.sparse, shape = (n_faces, n_faces)
            Normalized Hodge Laplacian matrix.

        Returns
        -------
        x_0 : torch.Tensor, shape = (n_nodes, in_channels_0)
            Output features on the nodes of the simplicial complex.
        x_1 : torch.Tensor, shape = (n_edges, in_channels_1)
            Output features on the edges of the simplicial complex.
        x_2 : torch.Tensor, shape = (n_faces, in_channels_2)
            Output features on the faces of the simplicial complex.

        Notes
        -----
        The Laplacians are handed to the convolutions exactly as received;
        this method neither builds nor normalizes them.
        """
        # Ranks never exchange messages here: each one is convolved with its
        # own Laplacian and then passed through a ReLU.
        x_0 = self.conv_0_to_0(x_0, laplacian_0)
        x_0 = torch.nn.functional.relu(x_0)

        x_1 = self.conv_1_to_1(x_1, laplacian_1)
        x_1 = torch.nn.functional.relu(x_1)

        x_2 = self.conv_2_to_2(x_2, laplacian_2)
        x_2 = torch.nn.functional.relu(x_2)

        return x_0, x_1, x_2