# Source code for topomodelx.nn.simplicial.sca_cmps_layer

"""Simplicial Complex Autoencoder Layer."""
import torch

from topomodelx.base.aggregation import Aggregation
from topomodelx.base.conv import Conv

class SCACMPSLayer(torch.nn.Module):
    """Layer of a Simplicial Complex Autoencoder (SCA) using the Coadjacency Message Passing Scheme (CMPS).

    Implementation of the SCA layer proposed in [1]_.

    Notes
    -----
    This is the architecture proposed for complex classification.

    One pair of convolutions is built for every rank ``i`` in
    ``1, ..., complex_dim - 1``:

    - a down-Laplacian convolution (rank ``i`` -> rank ``i``);
    - an incidence convolution (rank ``i - 1`` -> rank ``i``).

    Rank 0 has no layers; its features are returned unchanged by
    :meth:`forward`.

    Parameters
    ----------
    channels_list : list[int]
        Dimension of features at each dimension. Entries
        ``0, ..., complex_dim - 1`` are read, so the list must contain at
        least ``complex_dim`` elements.
    complex_dim : int
        Highest dimension of chains on the input simplicial complexes.
        NOTE(review): only ranks ``1 .. complex_dim - 1`` receive layers, so
        in practice this acts like the number of ranks; confirm against
        callers.
    att : bool, default=False
        Whether to use attention.

    References
    ----------
    .. [1] Hajij, Zamzmi, Papamarkou, Maroulas, Cai.
        Simplicial complex autoencoder (2022).
    .. [2] Papillon, Sanborn, Hajij, Miolane.
        Architectures of topological deep learning: a survey on topological
        neural networks (2023).
    .. [3] Papillon, Sanborn, Hajij, Miolane.
        Equations of topological neural networks (2023).
    """

    def __init__(
        self,
        channels_list,
        complex_dim,
        att: bool = False,
    ) -> None:
        super().__init__()
        self.att = att
        self.dim = complex_dim
        self.channels_list = channels_list
        lap_layers = []
        inc_layers = []
        # Layers are created rank by rank (Laplacian, then incidence). Changing
        # this order would likely change seeded parameter initialisation.
        for i in range(1, complex_dim):
            # Down-Laplacian convolution: rank i -> rank i.
            conv_layer_lap = Conv(
                in_channels=channels_list[i],
                out_channels=channels_list[i],
                att=att,
            )
            # Incidence convolution: rank i - 1 -> rank i.
            conv_layer_inc = Conv(
                in_channels=channels_list[i - 1],
                out_channels=channels_list[i],
                att=att,
            )
            lap_layers.append(conv_layer_lap)
            inc_layers.append(conv_layer_inc)

        self.lap_layers = torch.nn.ModuleList(lap_layers)
        self.inc_layers = torch.nn.ModuleList(inc_layers)
        # Sum aggregation used by `intra_aggr` over the rows of a feature matrix.
        self.aggr = Aggregation(
            aggr_func="sum",
            update_func=None,
        )
        # Mean + ReLU aggregation that merges the Laplacian and incidence
        # messages of one rank in `forward`.
        self.inter_aggr = Aggregation(
            aggr_func="mean",
            update_func="relu",
        )

    def reset_parameters(self) -> None:
        r"""Reset parameters of each layer.

        Calls ``reset_parameters`` on every ``Conv`` in ``lap_layers`` and then
        on every ``Conv`` in ``inc_layers``.
        """
        for layer in self.lap_layers:
            if isinstance(layer, Conv):
                layer.reset_parameters()
        for layer in self.inc_layers:
            if isinstance(layer, Conv):
                layer.reset_parameters()

    def weight_func(self, x):
        r"""Weight function for intra aggregation layer according to [1]_.

        Computes the logistic sigmoid ``1 / (1 + exp(-x))`` element-wise.
        ``torch.sigmoid`` is used rather than the explicit formula. The formula
        overflows ``exp(-x)`` for strongly negative inputs, which makes its
        autograd gradient NaN; ``torch.sigmoid`` stays finite.

        Parameters
        ----------
        x : torch.Tensor
            Input scores.

        Returns
        -------
        torch.Tensor
            Tensor of the same shape as ``x`` with values in ``[0, 1]``.
        """
        return torch.sigmoid(x)

    def intra_aggr(self, x):
        r"""Based on the use by [1]_.

        Each row of ``x`` is rescaled by a gate. The gate is ``weight_func``
        applied to the dot product of that row with the ReLU of the aggregated
        rows of ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Feature matrix, presumably 2-D of shape (n_cells, channels).

        Returns
        -------
        torch.Tensor
            Tensor of the same shape as ``x`` with each row scaled by its gate.
        """
        # One (1, channels) tensor per row, combined with the sum aggregation;
        # presumably this yields a (1, channels) tensor.
        x_list = list(torch.split(x, 1, dim=0))
        x_weight = self.aggr(x_list)
        # (1, channels) @ (channels, n_cells) -> one score per row of x.
        x_weight = torch.matmul(torch.relu(x_weight), x.transpose(1, 0))
        x_weight = self.weight_func(x_weight)
        # Turn the scores into a column so they broadcast across channels.
        return x_weight.transpose(1, 0) * x

    def forward(self, x_list, down_lap_list, incidencet_list):
        r"""Forward pass.

        The forward pass was initially proposed in [1]_.
        Its equations are given in [3]_ and graphically illustrated in [2]_.

        Coadjacency message passing scheme:

        .. math::
            \begin{align*}
            &🟥 \quad m_{y \rightarrow x}^{(r \rightarrow r'' \rightarrow r)} = M(h_{x}^{t, (r)}, h_{y}^{t, (r)},att(h_{x}^{t, (r)}, h_{y}^{t, (r)}),x,y,{\Theta^t}) \qquad \text{where } r'' < r < r'\\
            &🟥 \quad m_{y \rightarrow x}^{(r'' \rightarrow r)} = M(h_{x}^{t, (r)}, h_{y}^{t, (r'')},att(h_{x}^{t, (r)}, h_{y}^{t, (r'')}),x,y,{\Theta^t})\\
            &🟧 \quad m_x^{(r \rightarrow r)} = AGG_{y \in \mathcal{L}\_\downarrow(x)} m_{y \rightarrow x}^{(r \rightarrow r)}\\
            &🟧 \quad m_x^{(r'' \rightarrow r)} = AGG_{y \in \mathcal{B}(x)} m_{y \rightarrow x}^{(r'' \rightarrow r)}\\
            &🟩 \quad m_x^{(r)} = \text{AGG}\_{\mathcal{N}\_k \in \mathcal{N}}(m_x^{(k)})\\
            &🟦 \quad h_{x}^{t+1, (r)} = U(h_x^{t, (r)}, m_{x}^{(r)})
            \end{align*}

        Ranks are updated in increasing order. The incidence message for rank
        ``i`` reads the rank ``i - 1`` features already updated in this call
        (for ``i >= 2``).

        Parameters
        ----------
        x_list : list[torch.Tensor]
            List of tensors holding the features of each chain at each level.
            The list itself is not modified; a new list is returned.
        down_lap_list : list[torch.Tensor]
            List of down laplacian matrices for skeletons from 1 dimension to the dimension of the simplicial complex.
        incidencet_list : list[torch.Tensor]
            List of transpose incidence matrices for skeletons from 1 dimension to the dimension of the simplicial complex.

        Returns
        -------
        list[torch.Tensor]
            Output for skeletons of each dimension (the node features are left untouched: x_list[0]).
        """
        # Shallow copy so the caller's list is not mutated in place.
        x_list = list(x_list)
        for i in range(1, self.dim):
            x_lap = self.lap_layers[i - 1](x_list[i], down_lap_list[i - 1])
            x_inc = self.inc_layers[i - 1](x_list[i - 1], incidencet_list[i - 1])

            x_lap = self.intra_aggr(x_lap)
            x_inc = self.intra_aggr(x_inc)

            x_list[i] = self.inter_aggr([x_lap, x_inc])

        return x_list