# Source code for topomodelx.nn.simplicial.scconv
"""Simplicial 2-Complex Convolutional Network Implementation for binary node classification."""
import torch
from topomodelx.nn.simplicial.scconv_layer import SCConvLayer
# [docs] (documentation-link marker left over from the HTML source view)
class SCConv(torch.nn.Module):
    r"""Simplicial 2-complex convolutional network (SCConv).

    Stacks ``n_layers`` ``SCConvLayer`` modules and applies them one after
    the other to the node, edge and face features. This class adds no readout
    or classification head: ``forward`` returns the features of all three
    ranks as produced by the last layer.

    Parameters
    ----------
    node_channels : int
        Dimension of node (0-cells) features.
    edge_channels : int, optional
        Dimension of edge (1-cells) features. Defaults to ``node_channels``
        when ``None``.
    face_channels : int, optional
        Dimension of face (2-cells) features. Defaults to ``node_channels``
        when ``None``.
    n_layers : int, default=2
        Number of message passing layers.
    """

    def __init__(
        self, node_channels, edge_channels=None, face_channels=None, n_layers=2
    ):
        super().__init__()
        self.node_channels = node_channels
        # Edge/face widths fall back to the node width when not given.
        self.edge_channels = node_channels if edge_channels is None else edge_channels
        self.face_channels = node_channels if face_channels is None else face_channels
        self.n_layers = n_layers

        # Every layer is built with the same channel sizes, so the layers can
        # be chained directly: the outputs of one are the inputs of the next.
        self.layers = torch.nn.ModuleList(
            SCConvLayer(
                node_channels=self.node_channels,
                edge_channels=self.edge_channels,
                face_channels=self.face_channels,
            )
            for _ in range(n_layers)
        )

    def forward(
        self,
        x_0,
        x_1,
        x_2,
        incidence_1,
        incidence_1_norm,
        incidence_2,
        incidence_2_norm,
        adjacency_up_0_norm,
        adjacency_up_1_norm,
        adjacency_down_1_norm,
        adjacency_down_2_norm,
    ):
        r"""Run the features through every layer in sequence.

        The incidence and adjacency operators are passed unchanged to each
        layer; only the feature tensors are updated from layer to layer.

        Parameters
        ----------
        x_0 : torch.Tensor, shape = (n_nodes, node_channels)
            Input features on the nodes of the simplicial complex.
        x_1 : torch.Tensor, shape = (n_edges, edge_channels)
            Input features on the edges of the simplicial complex.
        x_2 : torch.Tensor, shape = (n_faces, face_channels)
            Input features on the faces of the simplicial complex.
        incidence_1 : torch.Tensor
            Incidence matrix of rank 1 :math:`B_1`.
        incidence_1_norm : torch.Tensor
            Normalized incidence matrix of rank 1 :math:`\tilde{B}_1`.
        incidence_2 : torch.Tensor
            Incidence matrix of rank 2 :math:`B_2`.
        incidence_2_norm : torch.Tensor
            Normalized incidence matrix of rank 2 :math:`\tilde{B}_2`.
        adjacency_up_0_norm : torch.Tensor
            Normalized upper adjacency matrix of rank 0.
        adjacency_up_1_norm : torch.Tensor
            Normalized upper adjacency matrix of rank 1.
        adjacency_down_1_norm : torch.Tensor
            Normalized down adjacency matrix of rank 1.
        adjacency_down_2_norm : torch.Tensor
            Normalized down adjacency matrix of rank 2.

        Returns
        -------
        x_0 : torch.Tensor
            Node features after the last layer.
        x_1 : torch.Tensor
            Edge features after the last layer.
        x_2 : torch.Tensor
            Face features after the last layer.

        Notes
        -----
        With zero layers the inputs are returned unchanged.
        """
        for layer in self.layers:
            x_0, x_1, x_2 = layer(
                x_0,
                x_1,
                x_2,
                incidence_1,
                incidence_1_norm,
                incidence_2,
                incidence_2_norm,
                adjacency_up_0_norm,
                adjacency_up_1_norm,
                adjacency_down_1_norm,
                adjacency_down_2_norm,
            )
        return x_0, x_1, x_2