Source code for topomodelx.nn.simplicial.scnn
"""Simplicial convolutional neural network implementation for complex classification."""
import torch
from topomodelx.nn.simplicial.scnn_layer import SCNNLayer
[docs]
class SCNN(torch.nn.Module):
    """Simplicial convolutional neural network implementation for complex classification.

    The network is a stack of ``SCNNLayer`` modules applied to features living on
    simplices of a single rank (e.g. edges). The output is therefore a feature per
    simplex; for complex-level classification the caller is expected to add a
    readout (e.g. a linear layer followed by an average over simplices), which is
    not part of this module.

    Parameters
    ----------
    in_channels : int
        Dimension of input features.
    hidden_channels : int
        Dimension of features of hidden layers. Every layer, including the
        last one, outputs ``hidden_channels`` features.
    conv_order_down : int
        Order of lower convolution.
    conv_order_up : int
        Order of upper convolution.
    aggr_norm : bool, default=False
        Aggregation-normalization flag forwarded unchanged to every
        ``SCNNLayer`` (its exact semantics are defined by ``SCNNLayer``).
    update_func : optional, default=None
        Update function forwarded unchanged to every ``SCNNLayer``
        (presumably ``None`` means no update/activation -- see ``SCNNLayer``).
    n_layers : int, default=2
        Number of layers. At least one layer is always built, even when
        ``n_layers < 1``.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        conv_order_down,
        conv_order_up,
        aggr_norm=False,
        update_func=None,
        n_layers=2,
    ):
        super().__init__()
        # Channel sizes along the stack: in -> hidden -> hidden -> ...
        # max(..., 1) keeps the historical behavior of always building the
        # first layer, even for n_layers <= 0.
        channels = [in_channels] + [hidden_channels] * max(n_layers, 1)
        # Every layer, including the first, shares the same convolution orders
        # and the same aggr_norm / update_func configuration.
        self.layers = torch.nn.ModuleList(
            [
                SCNNLayer(
                    in_channels=c_in,
                    out_channels=c_out,
                    conv_order_down=conv_order_down,
                    conv_order_up=conv_order_up,
                    aggr_norm=aggr_norm,
                    update_func=update_func,
                )
                for c_in, c_out in zip(channels[:-1], channels[1:])
            ]
        )

    def forward(self, x, laplacian_down, laplacian_up):
        """Forward computation.

        Parameters
        ----------
        x : torch.Tensor, shape = (n_simplices, in_channels)
            Features on simplices of a single rank (nodes, edges or faces).
        laplacian_down : torch.Tensor, shape = (n_simplices, n_simplices)
            Down Laplacian.
            For node features, laplacian_down = None.
        laplacian_up : torch.Tensor, shape = (n_simplices, n_simplices)
            Up Laplacian.

        Returns
        -------
        torch.Tensor, shape = (n_simplices, hidden_channels)
            Final hidden representation of the simplices.
        """
        # Layers are applied sequentially; the Laplacians are shared by all
        # layers since the rank of the simplices never changes.
        for layer in self.layers:
            x = layer(x, laplacian_down, laplacian_up)
        return x