Source code for topomodelx.nn.simplicial.hsn

"""High Skip Network Implementation model for binary node classification."""
import torch

from topomodelx.nn.simplicial.hsn_layer import HSNLayer


class HSN(torch.nn.Module):
    """High Skip Network Implementation for binary node classification.

    The network is a stack of ``n_layers`` :class:`HSNLayer` instances applied
    one after another to the node features. It returns the final node hidden
    representations. No classification head is applied inside this module.

    Parameters
    ----------
    channels : int
        Dimension of features (passed to every ``HSNLayer``).
    n_layers : int, default=2
        Amount of message passing layers. With ``n_layers=0`` the forward
        pass returns the input node features unchanged.
    """

    def __init__(self, channels, n_layers=2):
        super().__init__()
        # ModuleList (rather than a plain list) registers each layer as a
        # submodule, so its parameters are tracked by ``.parameters()``.
        self.layers = torch.nn.ModuleList(
            [HSNLayer(channels=channels) for _ in range(n_layers)]
        )

    def forward(self, x_0, incidence_1, adjacency_0):
        """Forward computation.

        Parameters
        ----------
        x_0 : torch.Tensor, shape = (n_nodes, channels)
            Node features.
        incidence_1 : torch.Tensor, shape = (n_nodes, n_edges)
            Boundary matrix of rank 1.
        adjacency_0 : torch.Tensor, shape = (n_nodes, n_nodes)
            Adjacency matrix (up) of rank 0.

        Returns
        -------
        torch.Tensor, shape = (n_nodes, channels)
            Final node hidden representations.
        """
        # Each layer receives the previous layer's node output. The same
        # incidence and adjacency matrices are passed to every layer.
        for layer in self.layers:
            x_0 = layer(x_0, incidence_1, adjacency_0)
        return x_0