Source code for topomodelx.nn.simplicial.sca_cmps
"""SCA with CMPS."""
import torch
from topomodelx.nn.simplicial.sca_cmps_layer import SCACMPSLayer
[docs]
class SCACMPS(torch.nn.Module):
    """SCA with CMPS.

    Stacks ``n_layers`` ``SCACMPSLayer`` modules, all built with the same
    arguments, and applies them one after another in ``forward``.

    Parameters
    ----------
    in_channels_all : list[int]
        Dimension of features on each node, edge, triangle, tetrahedron, ...
        respectively.
    complex_dim : int
        Highest dimension of simplicial complex feature being trained on.
    n_layers : int, default = 2
        Amount of message passing layers.
    att : bool, default = False
        Whether to use attention.
    """

    def __init__(
        self,
        in_channels_all,
        complex_dim,
        n_layers=2,
        att=False,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.in_channels_all = in_channels_all
        # One separate layer instance per message passing step; ModuleList
        # registers them as submodules of this model.
        self.layers = torch.nn.ModuleList(
            [SCACMPSLayer(in_channels_all, complex_dim, att) for _ in range(n_layers)]
        )

    def forward(self, x, laplacian_down_list, incidence_t_list):
        """Forward computation through the stacked message passing layers.

        Each layer receives the output of the previous one together with the
        same ``laplacian_down_list`` and ``incidence_t_list``. No additional
        linear layer or pooling is applied here.

        Parameters
        ----------
        x : list[torch.Tensor]
            Tensor inputs for each dimension of the complex (nodes, edges, etc.).
        laplacian_down_list : list[torch.Tensor]
            List of the down laplacian matrix for each dimension in the complex
            starting at edges.
        incidence_t_list : list[torch.Tensor]
            List of the transpose incidence matrices for the edges and faces.

        Returns
        -------
        list[torch.Tensor]
            Output of the last layer, or ``x`` unchanged when the model has no
            layers. NOTE(review): the exact return type is whatever
            ``SCACMPSLayer`` returns; presumably one tensor per dimension,
            mirroring ``x`` -- confirm against the layer implementation.
        """
        # Iterate over the registered layers directly instead of indexing by
        # ``n_layers`` so the loop always matches the modules actually held.
        for layer in self.layers:
            x = layer(x, laplacian_down_list, incidence_t_list)
        return x