# Source code for topomodelx.nn.simplicial.sccnn

"""SCCNN implementation for complex classification."""
import torch

from topomodelx.nn.simplicial.sccnn_layer import SCCNNLayer


class SCCNN(torch.nn.Module):
    """SCCNN implementation for complex classification.

    Note: In this task, we can consider the output on any order of simplices
    for the classification task, which of course can be amended by a readout
    layer.

    Parameters
    ----------
    in_channels_all : tuple of int
        Dimension of input features on (nodes, edges, faces).
    hidden_channels_all : tuple of int
        Dimension of features of hidden layers on (nodes, edges, faces).
    conv_order : int
        Order of convolutions, we consider the same order for all convolutions.
    sc_order : int
        Order of simplicial complex.
    aggr_norm : bool, default=False
        Whether to normalize the aggregation.
    update_func : str, optional
        Update function for the simplicial complex convolution.
    n_layers : int, default=2
        Number of SCCNN layers stacked after the input projections.
    """

    def __init__(
        self,
        in_channels_all,
        hidden_channels_all,
        conv_order,
        sc_order,
        aggr_norm=False,
        update_func=None,
        n_layers=2,
    ):
        super().__init__()
        # Input projection: one independent linear layer per simplex rank maps
        # that rank's raw features (in_channels_all[i]) to its own hidden width
        # (hidden_channels_all[i]). The hidden widths are not required to be
        # equal across ranks.
        self.in_linear_0 = torch.nn.Linear(in_channels_all[0], hidden_channels_all[0])
        self.in_linear_1 = torch.nn.Linear(in_channels_all[1], hidden_channels_all[1])
        self.in_linear_2 = torch.nn.Linear(in_channels_all[2], hidden_channels_all[2])

        # Stack of SCCNN layers. Every layer is configured with
        # in_channels == out_channels == hidden_channels_all, so the per-rank
        # widths are preserved from layer to layer.
        self.layers = torch.nn.ModuleList(
            [
                SCCNNLayer(
                    in_channels=hidden_channels_all,
                    out_channels=hidden_channels_all,
                    conv_order=conv_order,
                    sc_order=sc_order,
                    aggr_norm=aggr_norm,
                    update_func=update_func,
                )
                for _ in range(n_layers)
            ]
        )

    def forward(self, x_all, laplacian_all, incidence_all):
        """Forward computation.

        Parameters
        ----------
        x_all : tuple of tensors
            Tuple of feature tensors (node, edge, face).
            Each entry shape = (n_simplices, channels).
        laplacian_all : tuple of tensors
            Tuple of Laplacian tensors (graph laplacian L0, down edge
            laplacian L1_d, upper edge laplacian L1_u, face laplacian L2).
            Each entry shape = (n_simplices, n_simplices).
            Passed unchanged to every layer.
        incidence_all : tuple of tensors
            Tuple of order 1 and 2 incidence matrices.
            Shape of B1 = [n_nodes, n_edges].
            Shape of B2 = [n_edges, n_faces].
            Passed unchanged to every layer.

        Returns
        -------
        x_all : tuple of tensors
            Tuple of final hidden state tensors (node, edge, face).
            Each entry shape = (n_simplices, channels).
            If the model has no layers, these are just the projected inputs.
        """
        x_0, x_1, x_2 = x_all

        # Project each rank to its hidden width before message passing.
        x_all = (
            self.in_linear_0(x_0),
            self.in_linear_1(x_1),
            self.in_linear_2(x_2),
        )

        # Forward through SCCNN: each layer's output is fed to the next one,
        # and the output of the last layer is returned as-is.
        for layer in self.layers:
            x_all = layer(x_all, laplacian_all, incidence_all)

        return x_all