Source code for topomodelx.nn.simplicial.sccn

"""Simplicial Complex Convolutional Network Implementation for binary node classification."""
import torch

from topomodelx.nn.simplicial.sccn_layer import SCCNLayer


[docs] class SCCN(torch.nn.Module): """Simplicial Complex Convolutional Network Implementation for binary node classification. Parameters ---------- channels : int Dimension of features. max_rank : int Maximum rank of the cells in the simplicial complex. n_layers : int Number of message passing layers. update_func : str Activation function used in aggregation layers. """ def __init__(self, channels, max_rank, n_layers=2, update_func="sigmoid"): super().__init__() self.layers = torch.nn.ModuleList( SCCNLayer( channels=channels, max_rank=max_rank, update_func=update_func, ) for _ in range(n_layers) )
[docs] def forward(self, features, incidences, adjacencies): """Forward computation. Parameters ---------- features : dict[int, torch.Tensor], length=max_rank+1, shape = (n_rank_r_cells, channels) Input features on the cells of the simplicial complex. incidences : dict[int, torch.sparse], length=max_rank, shape = (n_rank_r_minus_1_cells, n_rank_r_cells) Incidence matrices :math:`B_r` mapping r-cells to (r-1)-cells. adjacencies : dict[int, torch.sparse], length=max_rank, shape = (n_rank_r_cells, n_rank_r_cells) Adjacency matrices :math:`H_r` mapping cells to cells via lower and upper cells. Returns ------- Dict of torch.Tensor rank_0 : torch.Tensor Final hidden representations of nodes. rank_1 : torch.Tensor Final hidden representations of edges. rank_2 : torch.Tensor Final hidden representations of triangles. rank_3 : torch.Tensor Final hidden representations of tetrahedra. ... (up to max_rank) """ for layer in self.layers: features = layer(features, incidences, adjacencies) return features