Train a Simplex Convolutional Network (SCN) of Rank 2#

This notebook illustrates the SCN layer proposed in [YSB22] for a simplicial complex of rank 2, that is, one with 0-cells (nodes), 1-cells (edges), and 2-cells (faces) only.
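
To make the domain concrete, such a complex can be built directly with toponetx. The cell below is a minimal sketch, assuming the toponetx SimplicialComplex constructor and its shape attribute (the number of simplices per rank):

[ ]:
import toponetx as tnx

# Two filled triangles glued along the edge (1, 2): a rank-2 complex
# with 0-cells (nodes), 1-cells (edges), and 2-cells (faces).
sc = tnx.SimplicialComplex([[0, 1, 2], [1, 2, 3]])
print(sc.shape)  # expected: (4, 5, 2) -- 4 nodes, 5 edges, 2 faces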

References#

[YSB22]

Ruochen Yang, Frederic Sala, and Paul Bogdan. Efficient Representation Learning for Higher-Order Data with Simplicial Complexes. In Bastian Rieck and Razvan Pascanu, editors, Proceedings of the First Learning on Graphs Conference, volume 198 of Proceedings of Machine Learning Research, pages 13:1–13:21. PMLR, 09–12 Dec 2022. https://proceedings.mlr.press/v198/yang22a.html.

[13]:
import numpy as np
import toponetx as tnx
import torch
from sklearn.model_selection import train_test_split

from topomodelx.nn.simplicial.scn2 import SCN2
from topomodelx.utils.sparse import from_sparse

%load_ext autoreload
%autoreload 2
[14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cpu

Pre-processing#

Import dataset#

According to the original paper, SCN performs well at simplex classification. We therefore use shrec_16, a benchmark dataset for 3D mesh classification.

[15]:
shrec, _ = tnx.datasets.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
ys = ys.reshape((100, 1))  # 100 complexes, one scalar label each
simplexes = shrec["complexes"]
Loading shrec 16 small dataset...

done!
[16]:
i_complex = 6
print(
    f"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}."
)
The 6th simplicial complex has 252 nodes with features of dimension 6.
The 6th simplicial complex has 750 edges with features of dimension 10.
The 6th simplicial complex has 500 faces with features of dimension 7.

Define neighborhood structures#

Now we retrieve the neighborhood structures (i.e., their representative matrices) that we will use to send messages on the domain. In this case, we need the normalized Laplacian matrices on nodes, edges, and faces. We also convert the neighborhood structures to torch tensors.

[17]:
laplacian_0s = []
laplacian_1s = []
laplacian_2s = []
for x in simplexes:
    laplacian_0 = x.normalized_laplacian_matrix(rank=0)
    laplacian_1 = x.normalized_laplacian_matrix(rank=1)
    laplacian_2 = x.normalized_laplacian_matrix(rank=2)

    laplacian_0 = from_sparse(laplacian_0)
    laplacian_1 = from_sparse(laplacian_1)
    laplacian_2 = from_sparse(laplacian_2)

    laplacian_0s.append(laplacian_0)
    laplacian_1s.append(laplacian_1)
    laplacian_2s.append(laplacian_2)
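
As a sanity check, each rank-k Laplacian is square, with one row and column per k-simplex. A minimal sketch, where the expected sizes come from the inspection of the 6th complex above:

[ ]:
print(laplacian_0s[i_complex].shape)  # expected: torch.Size([252, 252]) -- one row/column per node
print(laplacian_1s[i_complex].shape)  # expected: torch.Size([750, 750]) -- one row/column per edge
print(laplacian_2s[i_complex].shape)  # expected: torch.Size([500, 500]) -- one row/column per face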

Train the Neural Network#

We create the model with our pre-made neighborhood structures and specify an optimizer.

[18]:
in_channels_0 = x_0s[i_complex].shape[1]
in_channels_1 = x_1s[i_complex].shape[1]
in_channels_2 = x_2s[i_complex].shape[1]
out_channels = 1
[19]:
class Network(torch.nn.Module):
    def __init__(
        self, in_channels_0, in_channels_1, in_channels_2, out_channels, n_layers=2
    ):
        super().__init__()
        self.base_model = SCN2(
            in_channels_0=in_channels_0,
            in_channels_1=in_channels_1,
            in_channels_2=in_channels_2,
            n_layers=n_layers,
        )
        self.lin_0 = torch.nn.Linear(in_channels_0, out_channels)
        self.lin_1 = torch.nn.Linear(in_channels_1, out_channels)
        self.lin_2 = torch.nn.Linear(in_channels_2, out_channels)

    def forward(self, x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2):
        x_0, x_1, x_2 = self.base_model(
            x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2
        )

        x_0 = self.lin_0(x_0)
        x_1 = self.lin_1(x_1)
        x_2 = self.lin_2(x_2)

        # Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.
        two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
        two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0
        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0
        zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
        zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0
        # Return the sum of the averages
        return (
            two_dimensional_cells_mean
            + one_dimensional_cells_mean
            + zero_dimensional_cells_mean
        )
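
Before training, a single forward pass is a cheap way to confirm that the readout returns one scalar per complex. This is a sketch that mirrors the tensor conversions used in the training loop below:

[ ]:
_model = Network(in_channels_0, in_channels_1, in_channels_2, out_channels=1)
_out = _model(
    torch.tensor(x_0s[i_complex]).float(),
    torch.tensor(x_1s[i_complex]).float(),
    torch.tensor(x_2s[i_complex]).float(),
    laplacian_0s[i_complex].float(),
    laplacian_1s[i_complex].float(),
    laplacian_2s[i_complex].float(),
)
print(_out.shape)  # expected: torch.Size([1]) -- one scalar prediction per complex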
[20]:
n_layers = 2
model = Network(
    in_channels_0, in_channels_1, in_channels_2, out_channels, n_layers=n_layers
)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
[21]:
test_size = 0.2
x_0s_train, x_0s_test = train_test_split(x_0s, test_size=test_size, shuffle=False)
x_1s_train, x_1s_test = train_test_split(x_1s, test_size=test_size, shuffle=False)
x_2s_train, x_2s_test = train_test_split(x_2s, test_size=test_size, shuffle=False)

laplacian_0s_train, laplacian_0s_test = train_test_split(
    laplacian_0s, test_size=test_size, shuffle=False
)
laplacian_1s_train, laplacian_1s_test = train_test_split(
    laplacian_1s, test_size=test_size, shuffle=False
)
laplacian_2s_train, laplacian_2s_test = train_test_split(
    laplacian_2s, test_size=test_size, shuffle=False
)

y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)
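
With 100 complexes and test_size=0.2 (and shuffle=False), the first 80 complexes form the training set and the last 20 the test set:

[ ]:
print(len(x_0s_train), len(x_0s_test))  # expected: 80 20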

The following cell performs the training, looping over the network for a small number of epochs.

[22]:
test_interval = 10
num_epochs = 100
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2, y in zip(
        x_0s_train,
        x_1s_train,
        x_2s_train,
        laplacian_0s_train,
        laplacian_1s_train,
        laplacian_2s_train,
        y_train,
        strict=False,
    ):
        x_0, x_1, x_2, y = (
            torch.tensor(x_0).float().to(device),
            torch.tensor(x_1).float().to(device),
            torch.tensor(x_2).float().to(device),
            torch.tensor(y).float().to(device),
        )
        laplacian_0, laplacian_1, laplacian_2 = (
            laplacian_0.float().to(device),
            laplacian_1.float().to(device),
            laplacian_2.float().to(device),
        )
        optimizer.zero_grad()
        y_hat = model(x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2)
        loss = loss_fn(y_hat, y)
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        model.eval()
        with torch.no_grad():
            for x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2, y in zip(
                x_0s_test,
                x_1s_test,
                x_2s_test,
                laplacian_0s_test,
                laplacian_1s_test,
                laplacian_2s_test,
                y_test,
                strict=False,
            ):
                x_0, x_1, x_2, y = (
                    torch.tensor(x_0).float().to(device),
                    torch.tensor(x_1).float().to(device),
                    torch.tensor(x_2).float().to(device),
                    torch.tensor(y).float().to(device),
                )
                laplacian_0, laplacian_1, laplacian_2 = (
                    laplacian_0.float().to(device),
                    laplacian_1.float().to(device),
                    laplacian_2.float().to(device),
                )
                y_hat = model(x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2)
                test_loss = loss_fn(y_hat, y)
            # Note: this reports the loss on the last test sample only.
            print(f"Test_loss: {test_loss:.4f}", flush=True)
Epoch: 1 loss: 106.0943
Epoch: 2 loss: 85.9616
Epoch: 3 loss: 83.0319
Epoch: 4 loss: 80.9816
Epoch: 5 loss: 79.5344
Epoch: 6 loss: 78.9647
Epoch: 7 loss: 78.0454
Epoch: 8 loss: 77.2352
Epoch: 9 loss: 77.0969
Epoch: 10 loss: 76.4979
Test_loss: 27.4806
Epoch: 11 loss: 76.7194
Epoch: 12 loss: 76.1580
Epoch: 13 loss: 78.0519
Epoch: 14 loss: 75.5705
Epoch: 15 loss: 76.1465
Epoch: 16 loss: 76.5432
Epoch: 17 loss: 78.0466
Epoch: 18 loss: 77.0446
Epoch: 19 loss: 80.5840
Epoch: 20 loss: 75.3278
Test_loss: 35.2363
Epoch: 21 loss: 75.6831
Epoch: 22 loss: 74.5437
Epoch: 23 loss: 75.6770
Epoch: 24 loss: 78.6057
Epoch: 25 loss: 75.3240
Epoch: 26 loss: 74.4333
Epoch: 27 loss: 75.4743
Epoch: 28 loss: 76.0289
Epoch: 29 loss: 75.6414
Epoch: 30 loss: 76.2671
Test_loss: 20.8572
Epoch: 31 loss: 75.7918
Epoch: 32 loss: 74.1627
Epoch: 33 loss: 73.6667
Epoch: 34 loss: 72.9490
Epoch: 35 loss: 72.8358
Epoch: 36 loss: 73.0702
Epoch: 37 loss: 73.5069
Epoch: 38 loss: 73.7774
Epoch: 39 loss: 73.2788
Epoch: 40 loss: 73.7978
Test_loss: 21.3558
Epoch: 41 loss: 74.9096
Epoch: 42 loss: 73.2390
Epoch: 43 loss: 72.4291
Epoch: 44 loss: 73.3779
Epoch: 45 loss: 72.3256
Epoch: 46 loss: 72.9241
Epoch: 47 loss: 72.3715
Epoch: 48 loss: 72.1551
Epoch: 49 loss: 72.7596
Epoch: 50 loss: 72.3155
Test_loss: 13.5706
Epoch: 51 loss: 73.5445
Epoch: 52 loss: 72.3427
Epoch: 53 loss: 74.1711
Epoch: 54 loss: 72.1126
Epoch: 55 loss: 71.3567
Epoch: 56 loss: 69.5716
Epoch: 57 loss: 70.7865
Epoch: 58 loss: 70.4044
Epoch: 59 loss: 69.9258
Epoch: 60 loss: 69.7257
Test_loss: 11.6862
Epoch: 61 loss: 68.8615
Epoch: 62 loss: 69.6709
Epoch: 63 loss: 69.1890
Epoch: 64 loss: 70.6630
Epoch: 65 loss: 68.8225
Epoch: 66 loss: 68.8715
Epoch: 67 loss: 68.1793
Epoch: 68 loss: 68.7412
Epoch: 69 loss: 71.5032
Epoch: 70 loss: 70.2721
Test_loss: 6.6078
Epoch: 71 loss: 68.2701
Epoch: 72 loss: 69.6752
Epoch: 73 loss: 64.3450
Epoch: 74 loss: 62.4395
Epoch: 75 loss: 62.2776
Epoch: 76 loss: 67.0761
Epoch: 77 loss: 63.7860
Epoch: 78 loss: 60.9918
Epoch: 79 loss: 60.4742
Epoch: 80 loss: 60.2551
Test_loss: 0.2153
Epoch: 81 loss: 60.0425
Epoch: 82 loss: 59.4667
Epoch: 83 loss: 58.2843
Epoch: 84 loss: 57.9831
Epoch: 85 loss: 57.4091
Epoch: 86 loss: 56.9356
Epoch: 87 loss: 57.3874
Epoch: 88 loss: 57.6860
Epoch: 89 loss: 56.4505
Epoch: 90 loss: 56.2496
Test_loss: 0.6815
Epoch: 91 loss: 56.8954
Epoch: 92 loss: 55.0596
Epoch: 93 loss: 55.2672
Epoch: 94 loss: 55.1773
Epoch: 95 loss: 55.1011
Epoch: 96 loss: 54.5242
Epoch: 97 loss: 54.0988
Epoch: 98 loss: 54.5479
Epoch: 99 loss: 54.0969
Epoch: 100 loss: 54.7246
Test_loss: 0.0298