Source code for topomodelx.nn.simplicial.scn2

"""Simplex Convolutional Network Implementation for binary node classification."""
import torch

from topomodelx.nn.simplicial.scn2_layer import SCN2Layer


[docs] class SCN2(torch.nn.Module): """Simplex Convolutional Network Implementation for binary node classification. Parameters ---------- in_channels_0 : int Dimension of input features on nodes. in_channels_1 : int Dimension of input features on edges. in_channels_2 : int Dimension of input features on faces. n_layers : int Amount of message passing layers. """ def __init__(self, in_channels_0, in_channels_1, in_channels_2, n_layers=2): super().__init__() self.layers = torch.nn.ModuleList( SCN2Layer( in_channels_0=in_channels_0, in_channels_1=in_channels_1, in_channels_2=in_channels_2, ) for _ in range(n_layers) )
[docs] def forward(self, x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2): """Forward computation. Parameters ---------- x_0 : torch.Tensor, shape = (n_nodes, channels) Node features. x_1 : torch.Tensor, shape = (n_edges, channels) Edge features. x_2 : torch.Tensor, shape = (n_faces, channels) Face features. Returns ------- x_0 : torch.Tensor, shape = (n_nodes, channels) Final node hidden states. x_1 : torch.Tensor, shape = (n_nodes, channels) Final edge hidden states. x_2 : torch.Tensor, shape = (n_nodes, channels) Final face hidden states. """ for layer in self.layers: x_0, x_1, x_2 = layer(x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2) return x_0, x_1, x_2