Source code for topomodelx.nn.cell.cwn

"""CWN class."""

import torch
import torch.nn.functional as F

from topomodelx.nn.cell.cwn_layer import CWNLayer


[docs] class CWN(torch.nn.Module): """Implementation of a specific version of CW network [1]_. Parameters ---------- in_channels_0 : int Dimension of input features on nodes (0-cells). in_channels_1 : int Dimension of input features on edges (1-cells). in_channels_2 : int Dimension of input features on faces (2-cells). hid_channels : int Dimension of hidden features. n_layers : int Number of CWN layers. **kwargs : optional Additional arguments CWNLayer. References ---------- .. [1] Bodnar, et al. Weisfeiler and Lehman go cellular: CW networks. NeurIPS 2021. https://arxiv.org/abs/2106.12575 """ def __init__( self, in_channels_0, in_channels_1, in_channels_2, hid_channels, n_layers, **kwargs, ): super().__init__() self.proj_0 = torch.nn.Linear(in_channels_0, hid_channels) self.proj_1 = torch.nn.Linear(in_channels_1, hid_channels) self.proj_2 = torch.nn.Linear(in_channels_2, hid_channels) self.layers = torch.nn.ModuleList( CWNLayer( in_channels_0=hid_channels, in_channels_1=hid_channels, in_channels_2=hid_channels, out_channels=hid_channels, **kwargs, ) for _ in range(n_layers) )
[docs] def forward( self, x_0, x_1, x_2, adjacency_0, incidence_2, incidence_1_t, ): """Forward computation through projection, convolutions, linear layers and average pooling. Parameters ---------- x_0 : torch.Tensor, shape = (n_nodes, in_channels_0) Input features on the nodes (0-cells). x_1 : torch.Tensor, shape = (n_edges, in_channels_1) Input features on the edges (1-cells). x_2 : torch.Tensor, shape = (n_faces, in_channels_2) Input features on the faces (2-cells). adjacency_0 : torch.Tensor, shape = (n_edges, n_edges) Upper-adjacency matrix of rank 1. incidence_2 : torch.Tensor, shape = (n_edges, n_faces) Boundary matrix of rank 2. incidence_1_t : torch.Tensor, shape = (n_edges, n_nodes) Coboundary matrix of rank 1. Returns ------- x_0 : torch.Tensor, shape = (n_nodes, in_channels_0) Final hidden states of the nodes (0-cells). x_1 : torch.Tensor, shape = (n_edges, in_channels_1) Final hidden states the edges (1-cells). x_2 : torch.Tensor, shape = (n_edges, in_channels_2) Final hidden states of the faces (2-cells). """ x_0 = F.elu(self.proj_0(x_0)) x_1 = F.elu(self.proj_1(x_1)) x_2 = F.elu(self.proj_2(x_2)) for layer in self.layers: x_1 = layer( x_0, x_1, x_2, adjacency_0, incidence_2, incidence_1_t, ) return x_0, x_1, x_2