Train a Simplicial Complex Autoencoder (SCA) with Coadjacency Message Passing Scheme (CMPS)#

🟥 \(\quad m_{y \rightarrow x}^{(r \rightarrow r'' \rightarrow r)} = M(h_{x}^{t, (r)}, h_{y}^{t, (r)}, att(h_{x}^{t, (r)}, h_{y}^{t, (r)}), x, y, \Theta^t) \qquad \text{where } r'' < r\)

🟥 \(\quad m_{y \rightarrow x}^{(r'' \rightarrow r)} = M(h_{x}^{t, (r)}, h_{y}^{t, (r'')}, att(h_{x}^{t, (r)}, h_{y}^{t, (r'')}), x, y, \Theta^t)\)

🟧 \(\quad m_x^{(r \rightarrow r'' \rightarrow r)} = \text{AGG}_{y \in \mathcal{L}_\downarrow(x)}\, m_{y \rightarrow x}^{(r \rightarrow r'' \rightarrow r)}\)

🟧 \(\quad m_x^{(r'' \rightarrow r)} = \text{AGG}_{y \in \mathcal{B}(x)}\, m_{y \rightarrow x}^{(r'' \rightarrow r)}\)

🟩 \(\quad m_x^{(r)} = \text{AGG}_{\mathcal{N}_k \in \mathcal{N}}\, m_x^{(k)}\)

🟦 \(\quad h_{x}^{t+1, (r)} = U(h_x^{t, (r)}, m_{x}^{(r)})\)

where the notation is defined in Papillon et al., Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023).
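Concretely, with sum aggregation and no attention, one CMPS step at rank \(r\) reduces to two sparse matrix products followed by a sum and a pointwise update. The cell below is an illustrative sketch under those simplifying assumptions, not TopoModelX's implementation; the weight matrices w_same and w_boundary are hypothetical.

[ ]:
import torch


def cmps_step(h_r, h_r_minus_1, laplacian_down_r, incidence_r_t, w_same, w_boundary):
    """One CMPS step at rank r (sketch: no attention, sum aggregation, ReLU update).

    h_r              : (n_r, d) features on the rank-r simplices
    h_r_minus_1      : (n_{r-1}, d) features on the rank-(r-1) simplices
    laplacian_down_r : (n_r, n_r) sparse lower Laplacian L_down,r
    incidence_r_t    : (n_r, n_{r-1}) sparse transposed incidence matrix B_r^T
    """
    # 🟥/🟧 same-rank messages routed through shared lower simplices,
    # aggregated by the lower Laplacian
    m_same = torch.sparse.mm(laplacian_down_r, h_r @ w_same)
    # 🟥/🟧 boundary messages from rank r-1, aggregated by B_r^T
    m_boundary = torch.sparse.mm(incidence_r_t, h_r_minus_1 @ w_boundary)
    # 🟩 inter-neighborhood aggregation (sum) and 🟦 update
    return torch.relu(m_same + m_boundary)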

[4]:
import numpy as np
import toponetx as tnx
import torch

from topomodelx.nn.simplicial.sca_cmps import SCACMPS
from topomodelx.utils.sparse import from_sparse

%load_ext autoreload
%autoreload 2

If GPUs are available, we will make use of them; otherwise, we will use the CPU.

[5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cpu

Pre-processing#

Import dataset#

The first step is to import the Karate Club (https://www.jstor.org/stable/3629752) dataset. This is a single graph with 34 nodes, each belonging to one of two social groups. We will use these groups for the task of node-level binary classification.

We must first lift our graph dataset into the simplicial complex domain. The resulting complex has dimension 4, and its shape (34, 78, 45, 11, 2) counts the simplices of each rank: nodes, edges, faces (triangles), tetrahedra and 4-simplices.

[6]:
dataset = tnx.datasets.karate_club(complex_type="simplicial")
print(dataset)
Simplicial Complex with shape (34, 78, 45, 11, 2) and dimension 4
[7]:
dataset.shape
[7]:
(34, 78, 45, 11, 2)

Define neighborhood structures.#

Coadjacency Message Passing Scheme (CMPS):#

The CMPS takes features on the edges and the faces and outputs features on the edges. For each rank \(r\), the first neighborhood matrix is the lower Laplacian \(L_{\downarrow, r}\) and the second is the transposed incidence matrix \(B_r^T\). For the faces this gives the level-2 lower Laplacian \(L_{\downarrow, 2}\) and the transpose of the face incidence matrix \(B_2^T\); we build the rank-1 counterparts as well.

[17]:
laplacian_down_1 = dataset.down_laplacian_matrix(rank=1)
laplacian_down_2 = dataset.down_laplacian_matrix(rank=2)
incidence_1_t = dataset.incidence_matrix(rank=1).T
incidence_2_t = dataset.incidence_matrix(rank=2).T

laplacian_down_1 = from_sparse(laplacian_down_1)
laplacian_down_2 = from_sparse(laplacian_down_2)
incidence_1_t = from_sparse(incidence_1_t)
incidence_2_t = from_sparse(incidence_2_t)

laplacian_down_list = [laplacian_down_1, laplacian_down_2]
incidence_t_list = [incidence_1_t, incidence_2_t]
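As a quick sanity check, the shapes of these operators follow the complex: \(L_{\downarrow,1}\) is edges \(\times\) edges, \(L_{\downarrow,2}\) is faces \(\times\) faces, \(B_1^T\) is edges \(\times\) nodes and \(B_2^T\) is faces \(\times\) edges. The snippet below was added for illustration and is not part of the original notebook.

[ ]:
# Expected for this lift: laplacian_down_1 -> (78, 78), laplacian_down_2 -> (45, 45),
# incidence_1_t -> (78, 34), incidence_2_t -> (45, 78).
for name, mat in [
    ("laplacian_down_1", laplacian_down_1),
    ("laplacian_down_2", laplacian_down_2),
    ("incidence_1_t", incidence_1_t),
    ("incidence_2_t", incidence_2_t),
]:
    print(name, tuple(mat.shape))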

Import signal#

We retrieve an input signal on the nodes, edges and faces. For each rank, the signal has shape \(n_\text{simplices} \times\) in_channels, where in_channels is the feature dimension of each simplex. Here, we have in_channels = channels_nodes = 2.

[18]:
x_0 = list(dataset.get_simplex_attributes("node_feat").values())
x_0 = torch.tensor(np.stack(x_0))
channels_nodes = x_0.shape[-1]
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")

x_1 = list(dataset.get_simplex_attributes("edge_feat").values())
x_1 = torch.tensor(np.stack(x_1))
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")

x_2 = list(dataset.get_simplex_attributes("face_feat").values())
x_2 = torch.tensor(np.stack(x_2))
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")
There are 34 nodes with features of dimension 2.
There are 78 edges with features of dimension 2.
There are 45 faces with features of dimension 2.

We also pre-define the number of output channels of the model, in this case the number of node classes.

[19]:
in_channels = x_0.shape[-1]
out_channels = 2

Define binary labels#

We retrieve the labels associated with the nodes. In the Karate Club dataset, two social groups emerge, so we assign each node a binary label indicating which group it belongs to.

We one-hot encode the binary labels and hold out the last four nodes for testing.

[20]:
y = np.array(
    [
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        0,
        1,
        1,
        1,
        1,
        0,
        0,
        1,
        1,
        0,
        1,
        0,
        1,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
    ]
)
y_true = np.zeros((34, 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y
y_train = y_true[:30]
y_test = y_true[-4:]

y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)
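As a quick check of the split (a snippet added for illustration, not part of the original notebook), we can look at the class balance: column 0 counts nodes in the first group and column 1 counts nodes in the second.

[ ]:
# Per-class counts in the one-hot training and test labels
print("train class counts:", y_train.sum(dim=0).tolist())
print("test class counts: ", y_test.sum(dim=0).tolist())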

Create the Neural Networks#

Using the SCACMPS model, which stacks SCACMPSLayer layers, we create a neural network with a configurable number of layers, each applying the CMPS at every rank.

[41]:
class Network(torch.nn.Module):
    def __init__(self, in_channels_all, out_channels, complex_dim, n_layers=1):
        super().__init__()
        self.base_model = SCACMPS(
            in_channels_all=in_channels_all,
            complex_dim=complex_dim,
            n_layers=n_layers,
        )
        self.lin0 = torch.nn.Linear(in_channels_all[0], out_channels)

    def forward(self, x, laplacian_down_list, incidence_t_list):
        x = self.base_model(x, laplacian_down_list, incidence_t_list)
        x_0 = self.lin0(x[0])
        return torch.softmax(x_0, dim=1)
[42]:
x = [x_0, x_1, x_2]
in_channels_all = [x_0[0].shape[-1], x_1[0].shape[-1], x_2[0].shape[-1]]
out_channels = 2
complex_dim = 3
n_layers = 1

model = Network(
    in_channels_all=in_channels_all,
    out_channels=out_channels,
    complex_dim=complex_dim,
    n_layers=n_layers,
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

Train the Neural Network#

The following cell performs the training, looping over the network for 200 epochs.

[43]:
test_interval = 10
num_epochs = 200
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()

    y_hat = model(x, laplacian_down_list, incidence_t_list)
    # The forward pass already applies a softmax, so y_hat holds probabilities;
    # use plain binary cross-entropy rather than the with-logits variant.
    loss = torch.nn.functional.binary_cross_entropy(
        y_hat[: len(y_train)].float(), y_train.float()
    )
    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred[: len(y_train)] == y_train).all(dim=1).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            y_hat_test = model(x, laplacian_down_list, incidence_t_list)
            y_pred_test = torch.where(
                y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)
            )
            test_accuracy = (
                torch.eq(y_pred_test[-len(y_test) :], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)
Epoch: 1 loss: 0.7272 Train_acc: 0.4333
Epoch: 2 loss: 0.7215 Train_acc: 0.5667
Epoch: 3 loss: 0.7166 Train_acc: 0.5667
Epoch: 4 loss: 0.7126 Train_acc: 0.5667
Epoch: 5 loss: 0.7097 Train_acc: 0.5667
Epoch: 6 loss: 0.7076 Train_acc: 0.5667
Epoch: 7 loss: 0.7062 Train_acc: 0.5667
Epoch: 8 loss: 0.7052 Train_acc: 0.5667
Epoch: 9 loss: 0.7044 Train_acc: 0.5667
Epoch: 10 loss: 0.7037 Train_acc: 0.5667
Test_acc: 0.0000
Epoch: 11 loss: 0.7028 Train_acc: 0.5667
Epoch: 12 loss: 0.7018 Train_acc: 0.5667
Epoch: 13 loss: 0.7006 Train_acc: 0.5667
Epoch: 14 loss: 0.6991 Train_acc: 0.5667
Epoch: 15 loss: 0.6974 Train_acc: 0.5667
Epoch: 16 loss: 0.6955 Train_acc: 0.5667
Epoch: 17 loss: 0.6934 Train_acc: 0.5667
Epoch: 18 loss: 0.6913 Train_acc: 0.5667
Epoch: 19 loss: 0.6891 Train_acc: 0.5667
Epoch: 20 loss: 0.6870 Train_acc: 0.5667
Test_acc: 0.0000
Epoch: 21 loss: 0.6852 Train_acc: 0.5667
Epoch: 22 loss: 0.6836 Train_acc: 0.5667
Epoch: 23 loss: 0.6822 Train_acc: 0.7000
Epoch: 24 loss: 0.6810 Train_acc: 0.9333
Epoch: 25 loss: 0.6798 Train_acc: 0.9667
Epoch: 26 loss: 0.6784 Train_acc: 1.0000
Epoch: 27 loss: 0.6768 Train_acc: 1.0000
Epoch: 28 loss: 0.6750 Train_acc: 1.0000
Epoch: 29 loss: 0.6731 Train_acc: 0.9667
Epoch: 30 loss: 0.6712 Train_acc: 0.9333
Test_acc: 0.5000
Epoch: 31 loss: 0.6695 Train_acc: 0.9333
Epoch: 32 loss: 0.6679 Train_acc: 0.9333
Epoch: 33 loss: 0.6664 Train_acc: 0.9000
Epoch: 34 loss: 0.6651 Train_acc: 0.9000
Epoch: 35 loss: 0.6637 Train_acc: 0.9000
Epoch: 36 loss: 0.6623 Train_acc: 0.9000
Epoch: 37 loss: 0.6608 Train_acc: 0.9000
Epoch: 38 loss: 0.6593 Train_acc: 0.9333
Epoch: 39 loss: 0.6577 Train_acc: 0.9333
Epoch: 40 loss: 0.6562 Train_acc: 0.9333
Test_acc: 0.7500
Epoch: 41 loss: 0.6547 Train_acc: 0.9667
Epoch: 42 loss: 0.6533 Train_acc: 1.0000
Epoch: 43 loss: 0.6520 Train_acc: 1.0000
Epoch: 44 loss: 0.6507 Train_acc: 1.0000
Epoch: 45 loss: 0.6494 Train_acc: 0.9667
Epoch: 46 loss: 0.6481 Train_acc: 0.9667
Epoch: 47 loss: 0.6468 Train_acc: 0.9667
Epoch: 48 loss: 0.6455 Train_acc: 1.0000
Epoch: 49 loss: 0.6441 Train_acc: 1.0000
Epoch: 50 loss: 0.6429 Train_acc: 1.0000
Test_acc: 1.0000
Epoch: 51 loss: 0.6417 Train_acc: 1.0000
Epoch: 52 loss: 0.6405 Train_acc: 1.0000
Epoch: 53 loss: 0.6393 Train_acc: 1.0000
Epoch: 54 loss: 0.6382 Train_acc: 1.0000
Epoch: 55 loss: 0.6370 Train_acc: 1.0000
Epoch: 56 loss: 0.6358 Train_acc: 1.0000
Epoch: 57 loss: 0.6347 Train_acc: 1.0000
Epoch: 58 loss: 0.6336 Train_acc: 1.0000
Epoch: 59 loss: 0.6325 Train_acc: 0.9667
Epoch: 60 loss: 0.6314 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 61 loss: 0.6304 Train_acc: 0.9667
Epoch: 62 loss: 0.6294 Train_acc: 0.9667
Epoch: 63 loss: 0.6283 Train_acc: 0.9667
Epoch: 64 loss: 0.6273 Train_acc: 0.9667
Epoch: 65 loss: 0.6263 Train_acc: 0.9667
Epoch: 66 loss: 0.6253 Train_acc: 0.9667
Epoch: 67 loss: 0.6244 Train_acc: 0.9667
Epoch: 68 loss: 0.6234 Train_acc: 0.9667
Epoch: 69 loss: 0.6225 Train_acc: 0.9667
Epoch: 70 loss: 0.6216 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 71 loss: 0.6207 Train_acc: 0.9667
Epoch: 72 loss: 0.6198 Train_acc: 0.9667
Epoch: 73 loss: 0.6189 Train_acc: 0.9667
Epoch: 74 loss: 0.6180 Train_acc: 0.9667
Epoch: 75 loss: 0.6172 Train_acc: 0.9667
Epoch: 76 loss: 0.6164 Train_acc: 0.9667
Epoch: 77 loss: 0.6155 Train_acc: 0.9667
Epoch: 78 loss: 0.6147 Train_acc: 0.9667
Epoch: 79 loss: 0.6139 Train_acc: 0.9667
Epoch: 80 loss: 0.6131 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 81 loss: 0.6124 Train_acc: 0.9667
Epoch: 82 loss: 0.6116 Train_acc: 0.9667
Epoch: 83 loss: 0.6108 Train_acc: 0.9667
Epoch: 84 loss: 0.6101 Train_acc: 0.9667
Epoch: 85 loss: 0.6094 Train_acc: 0.9667
Epoch: 86 loss: 0.6086 Train_acc: 0.9667
Epoch: 87 loss: 0.6079 Train_acc: 0.9667
Epoch: 88 loss: 0.6072 Train_acc: 0.9667
Epoch: 89 loss: 0.6065 Train_acc: 0.9667
Epoch: 90 loss: 0.6058 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 91 loss: 0.6052 Train_acc: 0.9667
Epoch: 92 loss: 0.6045 Train_acc: 0.9667
Epoch: 93 loss: 0.6038 Train_acc: 0.9667
Epoch: 94 loss: 0.6032 Train_acc: 0.9667
Epoch: 95 loss: 0.6026 Train_acc: 0.9667
Epoch: 96 loss: 0.6019 Train_acc: 0.9667
Epoch: 97 loss: 0.6013 Train_acc: 0.9667
Epoch: 98 loss: 0.6007 Train_acc: 0.9667
Epoch: 99 loss: 0.6001 Train_acc: 0.9667
Epoch: 100 loss: 0.5995 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 101 loss: 0.5989 Train_acc: 0.9667
Epoch: 102 loss: 0.5983 Train_acc: 0.9667
Epoch: 103 loss: 0.5977 Train_acc: 0.9667
Epoch: 104 loss: 0.5972 Train_acc: 0.9667
Epoch: 105 loss: 0.5966 Train_acc: 0.9667
Epoch: 106 loss: 0.5961 Train_acc: 0.9667
Epoch: 107 loss: 0.5955 Train_acc: 0.9667
Epoch: 108 loss: 0.5950 Train_acc: 0.9667
Epoch: 109 loss: 0.5944 Train_acc: 0.9667
Epoch: 110 loss: 0.5939 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 111 loss: 0.5934 Train_acc: 0.9667
Epoch: 112 loss: 0.5929 Train_acc: 0.9667
Epoch: 113 loss: 0.5924 Train_acc: 0.9667
Epoch: 114 loss: 0.5919 Train_acc: 0.9667
Epoch: 115 loss: 0.5914 Train_acc: 0.9667
Epoch: 116 loss: 0.5909 Train_acc: 0.9667
Epoch: 117 loss: 0.5904 Train_acc: 0.9667
Epoch: 118 loss: 0.5899 Train_acc: 0.9667
Epoch: 119 loss: 0.5894 Train_acc: 0.9667
Epoch: 120 loss: 0.5890 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 121 loss: 0.5885 Train_acc: 0.9667
Epoch: 122 loss: 0.5881 Train_acc: 0.9667
Epoch: 123 loss: 0.5876 Train_acc: 0.9667
Epoch: 124 loss: 0.5872 Train_acc: 0.9667
Epoch: 125 loss: 0.5867 Train_acc: 0.9667
Epoch: 126 loss: 0.5863 Train_acc: 0.9667
Epoch: 127 loss: 0.5859 Train_acc: 0.9667
Epoch: 128 loss: 0.5854 Train_acc: 0.9667
Epoch: 129 loss: 0.5850 Train_acc: 0.9667
Epoch: 130 loss: 0.5846 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 131 loss: 0.5842 Train_acc: 0.9667
Epoch: 132 loss: 0.5838 Train_acc: 0.9667
Epoch: 133 loss: 0.5834 Train_acc: 0.9667
Epoch: 134 loss: 0.5830 Train_acc: 0.9667
Epoch: 135 loss: 0.5826 Train_acc: 0.9667
Epoch: 136 loss: 0.5822 Train_acc: 0.9667
Epoch: 137 loss: 0.5818 Train_acc: 0.9667
Epoch: 138 loss: 0.5814 Train_acc: 0.9667
Epoch: 139 loss: 0.5810 Train_acc: 0.9667
Epoch: 140 loss: 0.5806 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 141 loss: 0.5803 Train_acc: 0.9667
Epoch: 142 loss: 0.5799 Train_acc: 0.9667
Epoch: 143 loss: 0.5795 Train_acc: 0.9667
Epoch: 144 loss: 0.5792 Train_acc: 0.9667
Epoch: 145 loss: 0.5788 Train_acc: 0.9667
Epoch: 146 loss: 0.5785 Train_acc: 0.9667
Epoch: 147 loss: 0.5781 Train_acc: 0.9667
Epoch: 148 loss: 0.5778 Train_acc: 0.9667
Epoch: 149 loss: 0.5774 Train_acc: 0.9667
Epoch: 150 loss: 0.5771 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 151 loss: 0.5768 Train_acc: 0.9667
Epoch: 152 loss: 0.5764 Train_acc: 0.9667
Epoch: 153 loss: 0.5761 Train_acc: 0.9667
Epoch: 154 loss: 0.5758 Train_acc: 0.9667
Epoch: 155 loss: 0.5755 Train_acc: 0.9667
Epoch: 156 loss: 0.5751 Train_acc: 0.9667
Epoch: 157 loss: 0.5748 Train_acc: 0.9667
Epoch: 158 loss: 0.5745 Train_acc: 0.9667
Epoch: 159 loss: 0.5742 Train_acc: 0.9667
Epoch: 160 loss: 0.5739 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 161 loss: 0.5736 Train_acc: 0.9667
Epoch: 162 loss: 0.5733 Train_acc: 0.9667
Epoch: 163 loss: 0.5730 Train_acc: 0.9667
Epoch: 164 loss: 0.5727 Train_acc: 0.9667
Epoch: 165 loss: 0.5724 Train_acc: 0.9667
Epoch: 166 loss: 0.5721 Train_acc: 0.9667
Epoch: 167 loss: 0.5718 Train_acc: 0.9667
Epoch: 168 loss: 0.5715 Train_acc: 0.9667
Epoch: 169 loss: 0.5712 Train_acc: 0.9667
Epoch: 170 loss: 0.5709 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 171 loss: 0.5707 Train_acc: 0.9667
Epoch: 172 loss: 0.5704 Train_acc: 0.9667
Epoch: 173 loss: 0.5701 Train_acc: 0.9667
Epoch: 174 loss: 0.5698 Train_acc: 0.9667
Epoch: 175 loss: 0.5696 Train_acc: 0.9667
Epoch: 176 loss: 0.5693 Train_acc: 0.9667
Epoch: 177 loss: 0.5690 Train_acc: 0.9667
Epoch: 178 loss: 0.5688 Train_acc: 0.9667
Epoch: 179 loss: 0.5685 Train_acc: 0.9667
Epoch: 180 loss: 0.5683 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 181 loss: 0.5680 Train_acc: 0.9667
Epoch: 182 loss: 0.5678 Train_acc: 0.9667
Epoch: 183 loss: 0.5675 Train_acc: 0.9667
Epoch: 184 loss: 0.5673 Train_acc: 0.9667
Epoch: 185 loss: 0.5670 Train_acc: 0.9667
Epoch: 186 loss: 0.5668 Train_acc: 0.9667
Epoch: 187 loss: 0.5665 Train_acc: 0.9667
Epoch: 188 loss: 0.5663 Train_acc: 0.9667
Epoch: 189 loss: 0.5660 Train_acc: 0.9667
Epoch: 190 loss: 0.5658 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 191 loss: 0.5656 Train_acc: 0.9667
Epoch: 192 loss: 0.5653 Train_acc: 0.9667
Epoch: 193 loss: 0.5651 Train_acc: 0.9667
Epoch: 194 loss: 0.5649 Train_acc: 0.9667
Epoch: 195 loss: 0.5647 Train_acc: 0.9667
Epoch: 196 loss: 0.5644 Train_acc: 0.9667
Epoch: 197 loss: 0.5642 Train_acc: 0.9667
Epoch: 198 loss: 0.5640 Train_acc: 0.9667
Epoch: 199 loss: 0.5638 Train_acc: 0.9667
Epoch: 200 loss: 0.5635 Train_acc: 0.9667
Test_acc: 1.0000
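Once training is complete, we can read off hard predictions for the four held-out nodes by taking the argmax over the two class columns. This evaluation snippet is a suggested addition and was not part of the original run.

[ ]:
# Final hard predictions for the held-out nodes
model.eval()
with torch.no_grad():
    probs = model(x, laplacian_down_list, incidence_t_list)
pred = probs[-len(y_test) :].argmax(dim=1)
true = y_test.argmax(dim=1)
print("Predicted:", pred.tolist())
print("True:     ", true.tolist())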