# Source code for topomodelx.nn.simplicial.scnn

"""Simplicial convolutional neural network implementation for complex classification."""
import torch

from topomodelx.nn.simplicial.scnn_layer import SCNNLayer

class SCNN(torch.nn.Module):
    """Simplicial convolutional neural network implementation for complex classification.

    Note: the last layer produces an output on simplices (e.g., edges). To
    perform the complex classification task, pass this output to a linear
    layer and average it; that readout is not performed by this module.

    Parameters
    ----------
    in_channels : int
        Dimension of input features.
    hidden_channels : int
        Dimension of features of hidden layers. Every layer outputs this
        dimension, so it is also the dimension of the returned features.
    conv_order_down : int
        Order of lower convolution.
    conv_order_up : int
        Order of upper convolution.
    aggr_norm : bool, default=False
        Forwarded unchanged to every ``SCNNLayer`` (including the first);
        presumably controls normalization of the aggregated features, so
        confirm its meaning against ``SCNNLayer``.
    update_func : optional, default=None
        Forwarded unchanged to every ``SCNNLayer`` (including the first) as
        its update function; ``None`` defers to the layer's own default.
    n_layers : int, default=2
        Number of layers. At least one layer is always built, so values
        below 1 behave like 1.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        conv_order_down,
        conv_order_up,
        aggr_norm=False,
        update_func=None,
        n_layers=2,
    ):
        super().__init__()
        # Options shared by every layer. The first layer must receive them
        # too; otherwise aggr_norm/update_func would not apply to it.
        layer_kwargs = {
            "conv_order_down": conv_order_down,
            "conv_order_up": conv_order_up,
            "aggr_norm": aggr_norm,
            "update_func": update_func,
        }
        # Input width of each layer: the first maps in_channels ->
        # hidden_channels, and the remaining n_layers - 1 keep hidden_channels.
        # For n_layers <= 1 the repeated part is empty, leaving one layer.
        layer_in_channels = [in_channels] + [hidden_channels] * (n_layers - 1)
        self.layers = torch.nn.ModuleList(
            SCNNLayer(
                in_channels=channels,
                out_channels=hidden_channels,
                **layer_kwargs,
            )
            for channels in layer_in_channels
        )

    def forward(self, x, laplacian_down, laplacian_up):
        """Forward computation.

        The same Laplacians are fed to every layer in sequence.

        Parameters
        ----------
        x : torch.Tensor, shape = (n_simplices, in_channels)
            Tensor of features on nodes, edges, or faces.
        laplacian_down : torch.Tensor, shape = (n_simplices, n_simplices)
            Down Laplacian. For node features, pass None; how None is
            handled is delegated to ``SCNNLayer``.
        laplacian_up : torch.Tensor, shape = (n_simplices, n_simplices)
            Up Laplacian.

        Returns
        -------
        torch.Tensor, shape = (n_simplices, hidden_channels)
            Final hidden representation produced by the last layer.
        """
        for layer in self.layers:
            x = layer(x, laplacian_down, laplacian_up)
        return x