# Source code for topomodelx.nn.simplicial.dist2cycle_layer

"""Dist2Cycle Network Layer."""
import torch
import torch.nn as nn

from topomodelx.base.aggregation import Aggregation


class Dist2CycleLayer(torch.nn.Module):
    """Layer of Dist2Cycle [1]_.

    The layer owns a single square linear map (``channels -> channels``) and
    an ``Aggregation`` module configured with ``aggr_func="sum"`` and
    ``update_func="relu"``.

    Parameters
    ----------
    channels : int
        Dimension of features. Used as both the input and the output size of
        the internal linear map.

    References
    ----------
    .. [1] Papillon, Sanborn, Hajij, Miolane.
        Equations of topological neural networks (2023).
        https://github.com/awesome-tnns/awesome-tnns/
    """

    def __init__(
        self,
        channels: int,
    ) -> None:
        super().__init__()
        self.channels = channels

        # Feature learning: square linear map applied at the end of `forward`.
        self.fc_neigh = nn.Linear(channels, channels, bias=True)

        # TODO: support other update functions such as leaky ReLU, which is
        # the main non-linearity for Dist2Cycle.
        self.aggr_on_edges = Aggregation(aggr_func="sum", update_func="relu")

    def reset_parameters(self) -> None:
        r"""Reset learnable parameters.

        Restores the default ``nn.Linear`` initialisation (weight and bias)
        and then overwrites the weight with a Kaiming-uniform draw
        parameterised for a ReLU non-linearity. The bias keeps the value set
        by ``nn.Linear.reset_parameters``.
        """
        fc_nonlin = "relu"
        fc_alpha = 0.0

        self.fc_neigh.reset_parameters()
        nn.init.kaiming_uniform_(
            self.fc_neigh.weight, a=fc_alpha, nonlinearity=fc_nonlin
        )

    def forward(
        self,
        x_e: torch.Tensor,
        Linv: torch.Tensor,
        adjacency: torch.Tensor,
    ) -> torch.Tensor:
        r"""Forward pass.

        .. math::
            \begin{align*}
            &🟥 \quad m_{y \rightarrow x}^{(1 \rightarrow 1)}
                = (A \odot (I + L_{\downarrow})^{+})_{xy}
                \cdot h_{y}^{t,(1)} \cdot \Theta^t\\
            &🟧 \quad m_x^{(1 \rightarrow 1)}
                = \sum_{y \in \mathcal{L}_\downarrow(x)}
                m_{y \rightarrow x}^{(1 \rightarrow 1)}\\
            &🟩 \quad m_x^{(1)} = m^{(1 \rightarrow 1)}_x\\
            &🟦 \quad h_x^{t+1,(1)} = \sigma(m_{x}^{(1)})
            \end{align*}

        Parameters
        ----------
        x_e : torch.Tensor
            Input features on the edges of the simplicial complex.
            NOTE(review): this argument is currently not read; it is
            overwritten by ``adjacency * Linv`` before any use. Confirm
            whether the features were meant to enter the message.
        Linv : torch.Tensor
            Matrix multiplied element-wise with ``adjacency``. Presumably the
            pseudo-inverse :math:`(I + L_{\downarrow})^{+}` from the equation
            above -- confirm against the caller.
        adjacency : torch.Tensor
            Matrix :math:`A` multiplied element-wise with ``Linv``; it must be
            broadcast-compatible with ``Linv``.

        Returns
        -------
        torch.Tensor
            Output of the linear map applied to the aggregated product. Its
            last dimension is ``channels``; this assumes the aggregation
            preserves the shape of its single input (TODO confirm), in which
            case the last dimension of ``adjacency * Linv`` must also equal
            ``channels``.
        """
        # Element-wise (Hadamard) product, matching the A ⊙ (...) term.
        weighted = adjacency * Linv
        # Aggregation receives a one-element list of messages.
        aggregated = self.aggr_on_edges([weighted])
        return self.fc_neigh(aggregated)