Train a Simplicial High-Skip Network (HSN)#
In this notebook, we will create and train a High Skip Network in the simplicial complex domain, as proposed in Hajij et al.: High Skip Networks: A Higher Order Generalization of Skip Connections (2022).
We train the model to perform binary node classification on the Karate Club benchmark dataset.
The equations of one layer of this neural network are given by:
🟥 \(\quad m_{y \rightarrow z}^{(0 \rightarrow 0)} = \sigma ((A_{\uparrow,0})_{zy} \cdot h^{t,(0)}_y \cdot \Theta^{t,(0)1})\) (level 1)
🟥 \(\quad m_{z \rightarrow x}^{(0 \rightarrow 0)} = (A_{\uparrow,0})_{xz} \cdot m_{y \rightarrow z}^{(0 \rightarrow 0)} \cdot \Theta^{t,(0)2}\) (level 2)
🟥 \(\quad m_{y \rightarrow z}^{(0 \rightarrow 1)} = \sigma((B_1^T)_{zy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(0 \rightarrow 1)})\) (level 1)
🟥 \(\quad m_{z \rightarrow x}^{(1 \rightarrow 0)} = (B_1)_{xz} \cdot m_{y \rightarrow z}^{(0 \rightarrow 1)} \cdot \Theta^{t,(1 \rightarrow 0)}\) (level 2)
🟧 \(\quad m_{x}^{(0 \rightarrow 0)} = \sum_{z \in \mathcal{L}_\uparrow(x)} m_{z \rightarrow x}^{(0 \rightarrow 0)}\)
🟧 \(\quad m_{x}^{(1 \rightarrow 0)} = \sum_{z \in \mathcal{C}(x)} m_{z \rightarrow x}^{(1 \rightarrow 0)}\)
🟩 \(\quad m_x^{(0)} = m_x^{(0 \rightarrow 0)} + m_x^{(1 \rightarrow 0)}\)
🟦 \(\quad h_x^{t+1,(0)} = I(m_x^{(0)})\)
where the notation is defined in Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023).
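For intuition, here is a minimal sketch of the two-level message passing above, assuming dense tensors. The function name hsn_layer_sketch and the weight names (theta_00_1, theta_00_2, theta_01, theta_10) are hypothetical, chosen only to mirror the \(\Theta\) symbols; the actual layer is implemented in topomodelx.nn.simplicial.hsn.
[ ]:
import torch


def hsn_layer_sketch(
    h0, adjacency_0, incidence_1, theta_00_1, theta_00_2, theta_01, theta_10
):
    """One HSN layer on node features h0 of shape (n_nodes, channels)."""
    # Level 1, node -> node: sigma(A h Theta)
    m_00 = torch.sigmoid(adjacency_0 @ h0 @ theta_00_1)
    # Level 2, node -> node skip: A (level-1 message) Theta
    m_00 = adjacency_0 @ m_00 @ theta_00_2
    # Level 1, node -> edge: sigma(B1^T h Theta)
    m_01 = torch.sigmoid(incidence_1.T @ h0 @ theta_01)
    # Level 2, edge -> node: B1 (level-1 message) Theta
    m_10 = incidence_1 @ m_01 @ theta_10
    # The matrix products already perform the within-neighborhood sums (🟧);
    # between-neighborhood aggregation is a sum (🟩) and the update is the
    # identity (🟦).
    return m_00 + m_10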
[1]:
import numpy as np
import toponetx as tnx
import torch
from topomodelx.nn.simplicial.hsn import HSN
from topomodelx.utils.sparse import from_sparse
%load_ext autoreload
%autoreload 2
Pre-processing#
Import dataset#
The first step is to import the Karate Club (https://www.jstor.org/stable/3629752) dataset. This is a single graph with 34 nodes that belong to two different social groups. We will use these groups for the task of node-level binary classification.
We must first lift our graph dataset into the simplicial complex domain. In the printed shape below, each entry counts the simplices of one rank: 34 nodes, 78 edges, 45 triangles, 11 tetrahedra, and 2 4-simplices.
[2]:
dataset = tnx.datasets.karate_club(complex_type="simplicial")
print(dataset)
Simplicial Complex with shape (34, 78, 45, 11, 2) and dimension 4
Define neighborhood structures#
Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messages on the domain. In this case, we need the boundary matrix (or incidence matrix) \(B_1\) and the adjacency matrix \(A_{\uparrow,0}\) on the nodes. As a sanity check, we show that \(B_1\) has shape \(n_\text{nodes} \times n_\text{edges}\) and \(A_{\uparrow,0}\) has shape \(n_\text{nodes} \times n_\text{nodes}\). We also convert the neighborhood structures to torch tensors.
[3]:
incidence_1 = dataset.incidence_matrix(rank=1)
adjacency_0 = dataset.adjacency_matrix(rank=0)
incidence_1 = from_sparse(incidence_1)
adjacency_0 = from_sparse(adjacency_0)
print(f"The incidence matrix B1 has shape: {incidence_1.shape}.")
print(f"The adjacency matrix A0 has shape: {adjacency_0.shape}.")
The incidence matrix B1 has shape: torch.Size([34, 78]).
The adjacency matrix A0 has shape: torch.Size([34, 34]).
Import signal#
Since our task is node classification, we must retrieve an input signal on the nodes. The signal will have shape \(n_\text{nodes} \times\) in_channels, where in_channels is the dimension of each node's feature. Here, in_channels = channels_nodes is read directly from the data: as the cells below show, each of the 34 nodes carries a 2-dimensional feature vector.
[4]:
x_0 = list(dataset.get_simplex_attributes("node_feat").values())
x_0 = torch.tensor(np.stack(x_0))
channels_nodes = x_0.shape[-1]
[5]:
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")
There are 34 nodes with features of dimension 2.
Although we will not use them in this model, we show how to load edge features for demonstration purposes:
[6]:
x_1 = list(dataset.get_simplex_attributes("edge_feat").values())
x_1 = np.stack(x_1)
[7]:
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")
There are 78 edges with features of dimension 2.
Similarly for face features:
[8]:
x_2 = list(dataset.get_simplex_attributes("face_feat").values())
x_2 = np.stack(x_2)
[9]:
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")
There are 45 faces with features of dimension 2.
Define binary labels#
We retrieve the labels associated with the nodes. In the Karate Club dataset, two social groups emerge, so we assign each node a binary label indicating the group to which it belongs.
We convert the binary labels into one-hot form, and hold out the last four nodes' true labels for testing.
[10]:
y = np.array(
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1,
     1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
)
# One-hot encode the labels: column 0 marks group 1, column 1 marks group 0
y_true = np.zeros((34, 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y

# Hold out the last four nodes for testing
y_train = y_true[:30]
y_test = y_true[-4:]

y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)
Create the Neural Network#
Using the HSN class, we create our neural network with stacked layers. Given the considered dataset and task (Karate Club, node classification), a final linear layer produces an output of shape \(n_\text{nodes} \times 2\), which we can compare with our one-hot binary labels.
[11]:
class Network(torch.nn.Module):
    def __init__(self, channels, out_channels, n_layers=2):
        super().__init__()
        self.base_model = HSN(
            channels=channels,
            n_layers=n_layers,
        )
        self.linear = torch.nn.Linear(channels, out_channels)

    def forward(self, x, incidence_1, adjacency_0):
        x = self.base_model(x, incidence_1, adjacency_0)
        x = self.linear(x)
        return torch.softmax(x, dim=1)
[12]:
out_channels = 2
n_layers = 2

model = Network(
    channels=channels_nodes,
    out_channels=out_channels,
    n_layers=n_layers,
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
Train the Neural Network#
The following cell trains the network for a small number of epochs, evaluating on the held-out nodes every test_interval epochs.
[13]:
test_interval = 10
num_epochs = 100
for epoch_i in range(1, num_epochs + 1):
    model.train()
    optimizer.zero_grad()
    y_hat = model(x_0, incidence_1, adjacency_0)
    # The model outputs softmax probabilities, so we use binary cross-entropy
    # (not the with-logits variant) against the one-hot training labels.
    loss = torch.nn.functional.binary_cross_entropy(
        y_hat[: len(y_train)].float(), y_train.float()
    )
    loss.backward()
    optimizer.step()

    # Threshold the probabilities and score accuracy on the training nodes
    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred[: len(y_train)] == y_train).all(dim=1).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {loss.item():.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        model.eval()
        with torch.no_grad():
            y_hat_test = model(x_0, incidence_1, adjacency_0)
            y_pred_test = torch.where(
                y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)
            )
            test_accuracy = (
                torch.eq(y_pred_test[-len(y_test) :], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)
Epoch: 1 loss: 0.7200 Train_acc: 0.5667
Epoch: 2 loss: 0.7183 Train_acc: 0.5667
Epoch: 3 loss: 0.7167 Train_acc: 0.5667
Epoch: 4 loss: 0.7153 Train_acc: 0.5667
Epoch: 5 loss: 0.7142 Train_acc: 0.5667
Epoch: 6 loss: 0.7133 Train_acc: 0.5667
Epoch: 7 loss: 0.7125 Train_acc: 0.5667
Epoch: 8 loss: 0.7118 Train_acc: 0.5667
Epoch: 9 loss: 0.7111 Train_acc: 0.5667
Epoch: 10 loss: 0.7105 Train_acc: 0.5667
Test_acc: 0.0000
Epoch: 11 loss: 0.7099 Train_acc: 0.5667
Epoch: 12 loss: 0.7093 Train_acc: 0.5667
Epoch: 13 loss: 0.7087 Train_acc: 0.5667
Epoch: 14 loss: 0.7081 Train_acc: 0.5667
Epoch: 15 loss: 0.7076 Train_acc: 0.5667
Epoch: 16 loss: 0.7070 Train_acc: 0.5667
Epoch: 17 loss: 0.7066 Train_acc: 0.5667
Epoch: 18 loss: 0.7061 Train_acc: 0.6000
Epoch: 19 loss: 0.7058 Train_acc: 0.5667
Epoch: 20 loss: 0.7054 Train_acc: 0.5667
Test_acc: 0.2500
Epoch: 21 loss: 0.7050 Train_acc: 0.5667
Epoch: 22 loss: 0.7046 Train_acc: 0.5667
Epoch: 23 loss: 0.7042 Train_acc: 0.5667
Epoch: 24 loss: 0.7037 Train_acc: 0.5667
Epoch: 25 loss: 0.7033 Train_acc: 0.5667
Epoch: 26 loss: 0.7029 Train_acc: 0.5667
Epoch: 27 loss: 0.7025 Train_acc: 0.5667
Epoch: 28 loss: 0.7021 Train_acc: 0.5667
Epoch: 29 loss: 0.7017 Train_acc: 0.5333
Epoch: 30 loss: 0.7013 Train_acc: 0.5333
Test_acc: 0.5000
Epoch: 31 loss: 0.7009 Train_acc: 0.5667
Epoch: 32 loss: 0.7005 Train_acc: 0.5667
Epoch: 33 loss: 0.7001 Train_acc: 0.5667
Epoch: 34 loss: 0.6997 Train_acc: 0.5333
Epoch: 35 loss: 0.6993 Train_acc: 0.5333
Epoch: 36 loss: 0.6989 Train_acc: 0.5667
Epoch: 37 loss: 0.6985 Train_acc: 0.5667
Epoch: 38 loss: 0.6981 Train_acc: 0.5333
Epoch: 39 loss: 0.6977 Train_acc: 0.5333
Epoch: 40 loss: 0.6973 Train_acc: 0.5667
Test_acc: 0.5000
Epoch: 41 loss: 0.6969 Train_acc: 0.5667
Epoch: 42 loss: 0.6966 Train_acc: 0.5667
Epoch: 43 loss: 0.6962 Train_acc: 0.5667
Epoch: 44 loss: 0.6958 Train_acc: 0.5667
Epoch: 45 loss: 0.6955 Train_acc: 0.5667
Epoch: 46 loss: 0.6951 Train_acc: 0.5667
Epoch: 47 loss: 0.6948 Train_acc: 0.5667
Epoch: 48 loss: 0.6945 Train_acc: 0.5667
Epoch: 49 loss: 0.6942 Train_acc: 0.5667
Epoch: 50 loss: 0.6938 Train_acc: 0.5667
Test_acc: 0.5000
Epoch: 51 loss: 0.6935 Train_acc: 0.5667
Epoch: 52 loss: 0.6932 Train_acc: 0.5667
Epoch: 53 loss: 0.6929 Train_acc: 0.5667
Epoch: 54 loss: 0.6927 Train_acc: 0.5667
Epoch: 55 loss: 0.6924 Train_acc: 0.5667
Epoch: 56 loss: 0.6921 Train_acc: 0.5667
Epoch: 57 loss: 0.6919 Train_acc: 0.5667
Epoch: 58 loss: 0.6916 Train_acc: 0.5667
Epoch: 59 loss: 0.6914 Train_acc: 0.5667
Epoch: 60 loss: 0.6911 Train_acc: 0.5667
Test_acc: 0.5000
Epoch: 61 loss: 0.6909 Train_acc: 0.5667
Epoch: 62 loss: 0.6907 Train_acc: 0.5667
Epoch: 63 loss: 0.6905 Train_acc: 0.5667
Epoch: 64 loss: 0.6903 Train_acc: 0.6000
Epoch: 65 loss: 0.6901 Train_acc: 0.6000
Epoch: 66 loss: 0.6899 Train_acc: 0.6000
Epoch: 67 loss: 0.6897 Train_acc: 0.6000
Epoch: 68 loss: 0.6895 Train_acc: 0.6000
Epoch: 69 loss: 0.6893 Train_acc: 0.6000
Epoch: 70 loss: 0.6892 Train_acc: 0.6000
Test_acc: 0.5000
Epoch: 71 loss: 0.6890 Train_acc: 0.6000
Epoch: 72 loss: 0.6888 Train_acc: 0.6000
Epoch: 73 loss: 0.6887 Train_acc: 0.6000
Epoch: 74 loss: 0.6885 Train_acc: 0.6000
Epoch: 75 loss: 0.6884 Train_acc: 0.6000
Epoch: 76 loss: 0.6883 Train_acc: 0.6000
Epoch: 77 loss: 0.6881 Train_acc: 0.6000
Epoch: 78 loss: 0.6880 Train_acc: 0.6000
Epoch: 79 loss: 0.6879 Train_acc: 0.6000
Epoch: 80 loss: 0.6878 Train_acc: 0.6000
Test_acc: 0.5000
Epoch: 81 loss: 0.6876 Train_acc: 0.6000
Epoch: 82 loss: 0.6875 Train_acc: 0.6000
Epoch: 83 loss: 0.6874 Train_acc: 0.6000
Epoch: 84 loss: 0.6873 Train_acc: 0.6000
Epoch: 85 loss: 0.6872 Train_acc: 0.6000
Epoch: 86 loss: 0.6871 Train_acc: 0.6000
Epoch: 87 loss: 0.6870 Train_acc: 0.6000
Epoch: 88 loss: 0.6869 Train_acc: 0.6000
Epoch: 89 loss: 0.6868 Train_acc: 0.6000
Epoch: 90 loss: 0.6868 Train_acc: 0.6000
Test_acc: 0.5000
Epoch: 91 loss: 0.6867 Train_acc: 0.6000
Epoch: 92 loss: 0.6866 Train_acc: 0.6000
Epoch: 93 loss: 0.6865 Train_acc: 0.6000
Epoch: 94 loss: 0.6864 Train_acc: 0.6000
Epoch: 95 loss: 0.6864 Train_acc: 0.6000
Epoch: 96 loss: 0.6863 Train_acc: 0.6000
Epoch: 97 loss: 0.6862 Train_acc: 0.6000
Epoch: 98 loss: 0.6862 Train_acc: 0.6000
Epoch: 99 loss: 0.6861 Train_acc: 0.6000
Epoch: 100 loss: 0.6860 Train_acc: 0.6000
Test_acc: 0.5000
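After training, here is a minimal sketch of inference, using only the variables defined in this notebook: one forward pass over all 34 nodes, reading the predicted social group off the softmax probabilities.
[ ]:
model.eval()
with torch.no_grad():
    # Softmax probabilities for every node, shape (34, 2)
    probs = model(x_0, incidence_1, adjacency_0)
    # Predicted group index (0 or 1) per node
    predicted_group = probs.argmax(dim=1)
print(predicted_group)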