Source code for topomodelx.nn.cell.ccxn

"""CCXN class."""

import torch

from topomodelx.nn.cell.ccxn_layer import CCXNLayer


class CCXN(torch.nn.Module):
    """CCXN [1]_.

    A stack of ``n_layers`` identically configured ``CCXNLayer`` modules
    applied one after another.

    Parameters
    ----------
    in_channels_0 : int
        Dimension of input features on nodes.
    in_channels_1 : int
        Dimension of input features on edges.
    in_channels_2 : int
        Dimension of input features on faces.
    n_layers : int
        Number of CCXN layers.
    att : bool
        Whether to use attention.
    **kwargs : optional
        Additional arguments for CCXNLayer.

    References
    ----------
    .. [1] Hajij, Istvan, Zamzmi.
        Cell complex neural networks.
        Topological data analysis and beyond workshop at NeurIPS 2020.
        https://arxiv.org/pdf/2010.00743.pdf
    """

    def __init__(
        self,
        in_channels_0,
        in_channels_1,
        in_channels_2,
        n_layers=2,
        att=False,
        **kwargs,
    ):
        super().__init__()
        # Every layer receives the same configuration; a non-positive
        # ``n_layers`` yields an empty stack.
        self.layers = torch.nn.ModuleList(
            [
                CCXNLayer(
                    in_channels_0=in_channels_0,
                    in_channels_1=in_channels_1,
                    in_channels_2=in_channels_2,
                    att=att,
                    **kwargs,
                )
                for _ in range(n_layers)
            ]
        )

    def forward(self, x_0, x_1, adjacency_0, incidence_2_t):
        """Forward computation through layers.

        Parameters
        ----------
        x_0 : torch.Tensor, shape = (n_nodes, in_channels_0)
            Input features on the nodes (0-cells).
        x_1 : torch.Tensor, shape = (n_edges, in_channels_1)
            Input features on the edges (1-cells).
        adjacency_0 : torch.Tensor, shape = (n_nodes, n_nodes)
            Adjacency matrix of rank 0 (up).
        incidence_2_t : torch.Tensor, shape = (n_faces, n_edges)
            Transpose of boundary matrix of rank 2.

        Returns
        -------
        x_0 : torch.Tensor, shape = (n_nodes, in_channels_0)
            Final hidden states of the nodes (0-cells).
        x_1 : torch.Tensor, shape = (n_edges, in_channels_1)
            Final hidden states of the edges (1-cells).
        x_2 : torch.Tensor or None, shape = (n_faces, in_channels_2)
            Final hidden states of the faces (2-cells). ``None`` when the
            model holds no layers, since face features are only produced by
            a layer.
        """
        # Face features are produced by the layers themselves; start from
        # ``None`` so that an empty layer stack returns a well-defined value
        # instead of raising ``UnboundLocalError``.
        x_2 = None
        for layer in self.layers:
            x_0, x_1, x_2 = layer(x_0, x_1, adjacency_0, incidence_2_t)
        return (x_0, x_1, x_2)