Source code for topomodelx.nn.simplicial.dist2cycle

"""Dist2Cycle model for binary node classification."""
import torch

from topomodelx.nn.simplicial.dist2cycle_layer import Dist2CycleLayer


class Dist2Cycle(torch.nn.Module):
    """Dist2Cycle model for binary node classification.

    The model is a plain stack of ``Dist2CycleLayer`` modules that are
    applied one after another; the output of each layer is fed as the
    feature input of the next one.

    Parameters
    ----------
    channels : int
        Dimension of features. Passed as ``channels`` to every layer.
    n_layers : int, default=2
        Amount of message passing layers.
    """

    def __init__(self, channels: int, n_layers: int = 2) -> None:
        super().__init__()
        # A ModuleList (rather than a plain list) registers each layer as a
        # submodule, so their parameters are visible to optimizers,
        # ``.to(device)``, ``state_dict()``, etc.
        self.layers = torch.nn.ModuleList(
            Dist2CycleLayer(channels=channels) for _ in range(n_layers)
        )

    def forward(self, x_1e, Linv, adjacency):
        """Forward computation.

        Each layer receives the current features together with the same
        ``Linv`` and ``adjacency`` operators. With zero layers the input
        features are returned untouched.

        Parameters
        ----------
        x_1e : torch.Tensor, shape = (n_nodes, channels)
            Node features.
        Linv : torch.Tensor
            Forwarded unchanged to every layer. Presumably the
            (pseudo-)inverse of a Laplacian, judging by the name --
            TODO confirm against ``Dist2CycleLayer``.
        adjacency : torch.Tensor
            Forwarded unchanged to every layer. Presumably an adjacency
            matrix -- TODO confirm expected shape and sparsity against
            ``Dist2CycleLayer``.

        Returns
        -------
        torch.Tensor, shape = (n_nodes, channels)
            Final node hidden representations.
        """
        for layer in self.layers:
            x_1e = layer(x_1e, Linv, adjacency)
        return x_1e