Train a Simplicial Complex Convolutional Network (SCCN)
We create an SCCN model following Yang et al., Efficient Representation Learning for Higher-Order Data with Simplicial Complexes (LoG 2022).
We train the model to perform binary node classification using the KarateClub benchmark dataset.
The model operates on cells of all ranks up to some max rank \(r_\mathrm{max}\). The equations of one layer of this neural network are given by:
🟥 \(\quad m_{{y \rightarrow x}}^{(r \rightarrow r)} = (H_{r})_{xy} \cdot h^{t,(r)}_y \cdot \Theta^{t,(r\to r)}\), (for \(0\leq r \leq r_\mathrm{max}\))
🟥 \(\quad m_{{y \rightarrow x}}^{(r-1 \rightarrow r)} = (B_{r}^T)_{xy} \cdot h^{t,(r-1)}_y \cdot \Theta^{t,(r-1\to r)}\), (for \(1\leq r \leq r_\mathrm{max}\))
🟥 \(\quad m_{{y \rightarrow x}}^{(r+1 \rightarrow r)} = (B_{r+1})_{xy} \cdot h^{t,(r+1)}_y \cdot \Theta^{t,(r+1\to r)}\), (for \(0\leq r \leq r_\mathrm{max}-1\))
🟧 \(\quad m_{x}^{(r \rightarrow r)} = \sum_{y \in \mathcal{L}_\downarrow(x)\bigcup \mathcal{L}_\uparrow(x)} m_{y \rightarrow x}^{(r \rightarrow r)}\)
🟧 \(\quad m_{x}^{(r-1 \rightarrow r)} = \sum_{y \in \mathcal{B}(x)} m_{y \rightarrow x}^{(r-1 \rightarrow r)}\)
🟧 \(\quad m_{x}^{(r+1 \rightarrow r)} = \sum_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(r+1 \rightarrow r)}\)
🟩 \(\quad m_x^{(r)} = m_x^{(r \rightarrow r)} + m_x^{(r-1 \rightarrow r)} + m_x^{(r+1 \rightarrow r)}\)
🟦 \(\quad h_x^{t+1,(r)} = \sigma(m_x^{(r)})\)
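Stacking these per-cell equations over all cells of rank \(r\) gives a matrix form, \(h^{t+1,(r)} = \sigma\big(H_r h^{t,(r)} \Theta^{t,(r\to r)} + B_r^T h^{t,(r-1)} \Theta^{t,(r-1\to r)} + B_{r+1} h^{t,(r+1)} \Theta^{t,(r+1\to r)}\big)\). The zero entries of \(H_r\) and \(B_r\) exclude non-neighbors, so the neighborhood sums become matrix products. The NumPy sketch below computes this update for one rank on a hypothetical toy complex (a single filled triangle). The function `sccn_rank_update`, the random weights, and the sigmoid for \(\sigma\) are illustrative assumptions, not the TopoModelX implementation.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def sccn_rank_update(h_r, H_r, theta_rr,
                     h_lower=None, B_r=None, theta_low=None,
                     h_upper=None, B_next=None, theta_up=None):
    """Matrix-form SCCN update for the cells of one rank r (sketch).

    h_r:    (n_r, c)       features on rank-r cells
    H_r:    (n_r, n_r)     rank-r adjacency, including self-loops
    B_r:    (n_{r-1}, n_r) incidence between ranks r-1 and r (used transposed)
    B_next: (n_r, n_{r+1}) incidence between ranks r and r+1
    """
    # Same-rank messages: row x of H_r weights the features of x's neighbors.
    m = H_r @ h_r @ theta_rr
    # Messages from faces of rank r-1 (skipped for r = 0).
    if h_lower is not None:
        m = m + B_r.T @ h_lower @ theta_low
    # Messages from cofaces of rank r+1 (skipped for r = r_max).
    if h_upper is not None:
        m = m + B_next @ h_upper @ theta_up
    # Update step: elementwise nonlinearity, here a sigmoid.
    return sigmoid(m)


# Toy data: one filled triangle; update its edges (r = 1) with c = 2 channels.
rng = np.random.default_rng(0)
c = 2
B1 = np.array([[-1, -1, 0], [1, 0, -1], [0, 1, 1]], dtype=float)  # nodes x edges
B2 = np.array([[1.0], [-1.0], [1.0]])                              # edges x faces
H1 = 2.0 * np.ones((3, 3))  # hypothetical edge adjacency with self-loops
h0 = rng.normal(size=(3, c))
h1 = rng.normal(size=(3, c))
h2 = rng.normal(size=(1, c))
thetas = [rng.normal(size=(c, c)) for _ in range(3)]

h1_next = sccn_rank_update(h1, H1, thetas[0], h0, B1, thetas[1], h2, B2, thetas[2])
print(h1_next.shape)  # (3, 2)
```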
The notation is defined in Papillon et al., Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023).
[1]:
import numpy as np
import toponetx as tnx
import torch
from topomodelx.nn.simplicial.sccn import SCCN
from topomodelx.utils.sparse import from_sparse
%load_ext autoreload
%autoreload 2
Pre-processing
Import dataset
The first step is to import the Karate Club (https://www.jstor.org/stable/3629752) dataset. This is a single graph with 34 nodes that belong to two different social groups. We will use these groups for the task of node-level binary classification.
We must first lift our graph dataset into the simplicial complex domain.
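A graph can be lifted to a simplicial complex in several ways. A common choice, which we assume here for illustration, is the clique complex: every set of \(k+1\) pairwise-adjacent nodes becomes a \(k\)-simplex. The sketch below enumerates such simplices by brute force on a small hypothetical graph. `clique_complex` is our own helper, not the TopoNetX lifting routine, and brute force is only practical for tiny graphs.

```python
from itertools import combinations


def clique_complex(n_nodes, edges, max_dim):
    """Enumerate simplices of the clique complex up to dimension max_dim.

    A vertex set is a simplex iff every pair of its vertices is an edge.
    Brute force over all vertex subsets: only suitable for tiny graphs.
    """
    edge_set = {frozenset(e) for e in edges}
    simplices = {0: [(v,) for v in range(n_nodes)]}
    for k in range(1, max_dim + 1):
        simplices[k] = [
            s for s in combinations(range(n_nodes), k + 1)
            if all(frozenset(p) in edge_set for p in combinations(s, 2))
        ]
    return simplices


# Hypothetical graph: a 4-clique on {0, 1, 2, 3} plus a pendant edge (3, 4).
edges = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (3, 4)]
cx = clique_complex(5, edges, max_dim=3)
print([len(cx[k]) for k in range(4)])  # [5, 7, 4, 1]
```

The 4-clique contributes four triangles and one tetrahedron, while the pendant edge adds only an edge, which is why the counts per rank are 5, 7, 4 and 1.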
Since our task is node classification, we need an input signal on the nodes. The signal has shape \(n_\text{nodes} \times\) in_channels, where in_channels is the dimension of each cell’s feature, set here by the feat_dim argument.
[41]:
dataset = tnx.datasets.karate_club(complex_type="simplicial", feat_dim=2)
print(dataset)
Simplicial Complex with shape (34, 78, 45, 11, 2) and dimension 4
Define neighborhood structures
Our implementation allows for features on cells up to an arbitrary maximum rank. This dataset provides features only up to tetrahedra (rank 3), so we can use at most max_rank = 3, which is what we choose.
We define incidence and adjacency matrices up to the max rank and put them in dictionaries indexed by rank, as expected by the SCCNLayer. The adjacency and incidence matrices could be chosen arbitrarily; here we follow the original formulation by Yang et al. fairly closely. The adjacency \(H_r\) is modeled on the r-Hodge Laplacian: it is the sum of the upper and lower adjacency matrices of rank r (only the upper one for \(r=0\), only the lower one for \(r=r_\mathrm{max}\)), plus \(2I\) (or just \(I\) for \(r\in\{0, r_\mathrm{max}\}\)) so that cells can pass messages to themselves. The incidence matrices are the usual boundary matrices \(B_r\). One could additionally weight or normalize these matrices as suggested by Yang et al., but we refrain from doing so for simplicity.
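To make the construction concrete, the NumPy sketch below builds \(B_1\), \(B_2\) and \(H_0, H_1, H_2\) for a hypothetical complex made of a single filled triangle (so \(r_\mathrm{max}=2\) here). It assumes that the upper and lower adjacency of rank r are the off-diagonal entries of the up-Laplacian \(B_{r+1}B_{r+1}^T\) and down-Laplacian \(B_r^T B_r\) with signs dropped. This mirrors the roles of `adjacency_matrix` and `coadjacency_matrix` in the next cell, but the helper `unsigned_offdiag` is our own illustration, not TopoNetX code.

```python
import numpy as np

# Hypothetical complex: one filled triangle on vertices 0, 1, 2.
# Edge order: (0,1), (0,2), (1,2); single face (0,1,2).
B1 = np.array([[-1, -1, 0],
               [1, 0, -1],
               [0, 1, 1]], dtype=float)   # vertices x edges
B2 = np.array([[1.0], [-1.0], [1.0]])     # edges x faces


def unsigned_offdiag(L):
    """Drop signs and zero the diagonal of an up- or down-Laplacian."""
    A = np.abs(L)  # np.abs returns a copy, so L is left untouched
    np.fill_diagonal(A, 0.0)
    return A


A0_up = unsigned_offdiag(B1 @ B1.T)       # nodes sharing an edge
A1_up = unsigned_offdiag(B2 @ B2.T)       # edges sharing a face
A1_down = unsigned_offdiag(B1.T @ B1)     # edges sharing a vertex
A2_down = unsigned_offdiag(B2.T @ B2)     # faces sharing an edge (none here)

# Self-loops: I at the lowest and highest rank, 2I in between.
H0 = A0_up + np.eye(3)
H1 = A1_up + A1_down + 2 * np.eye(3)
H2 = A2_down + np.eye(1)
print(H0, H1, H2, sep="\n")
```

In this toy complex every pair of edges shares both a vertex and the face, so every off-diagonal entry of \(H_1\) is 2, and with the \(2I\) term the whole matrix is constant.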
[42]:
max_rank = 3 # There are features up to tetrahedron order in the dataset
[43]:
def sparse_to_torch(X):
    return from_sparse(X)
# Incidence (boundary) matrices B_1, ..., B_{max_rank}.
incidences = {
    f"rank_{r}": sparse_to_torch(dataset.incidence_matrix(rank=r))
    for r in range(1, max_rank + 1)
}
adjacencies = {}
# Rank 0: nodes only have an upper adjacency; add I for self-messages.
adjacencies["rank_0"] = (
    sparse_to_torch(dataset.adjacency_matrix(rank=0))
    + torch.eye(dataset.shape[0]).to_sparse()
)
# Intermediate ranks: upper plus lower adjacency, plus 2I.
for r in range(1, max_rank):
    adjacencies[f"rank_{r}"] = (
        sparse_to_torch(
            dataset.adjacency_matrix(rank=r) + dataset.coadjacency_matrix(rank=r)
        )
        + 2 * torch.eye(dataset.shape[r]).to_sparse()
    )
# Top rank: only the lower adjacency is used; add I.
adjacencies[f"rank_{max_rank}"] = (
    sparse_to_torch(dataset.coadjacency_matrix(rank=max_rank))
    + torch.eye(dataset.shape[max_rank]).to_sparse()
)
for r in range(max_rank + 1):
    print(f"The adjacency matrix H{r} has shape: {adjacencies[f'rank_{r}'].shape}.")
    if r > 0:
        print(f"The incidence matrix B{r} has shape: {incidences[f'rank_{r}'].shape}.")
The adjacency matrix H0 has shape: torch.Size([34, 34]).
The adjacency matrix H1 has shape: torch.Size([78, 78]).
The incidence matrix B1 has shape: torch.Size([34, 78]).
The adjacency matrix H2 has shape: torch.Size([45, 45]).
The incidence matrix B2 has shape: torch.Size([78, 45]).
The adjacency matrix H3 has shape: torch.Size([11, 11]).
The incidence matrix B3 has shape: torch.Size([45, 11]).
Import signal
We import the features at each rank.
[44]:
x_0 = list(dataset.get_simplex_attributes("node_feat").values())
x_0 = torch.tensor(np.stack(x_0))
channels_nodes = x_0.shape[-1]
[45]:
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")
There are 34 nodes with features of dimension 2.
Load edge features.
[46]:
x_1 = list(dataset.get_simplex_attributes("edge_feat").values())
x_1 = torch.tensor(np.stack(x_1))
[47]:
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")
There are 78 edges with features of dimension 2.
Similarly for face features:
[48]:
x_2 = list(dataset.get_simplex_attributes("face_feat").values())
x_2 = torch.tensor(np.stack(x_2))
[49]:
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")
There are 45 faces with features of dimension 2.
Higher-order (tetrahedron) features:
[50]:
x_3 = list(dataset.get_simplex_attributes("tetrahedron_feat").values())
x_3 = torch.tensor(np.stack(x_3))
[51]:
print(
f"There are {x_3.shape[0]} tetrahedrons with features of dimension {x_3.shape[1]}."
)
There are 11 tetrahedrons with features of dimension 2.
The features are organized in a dictionary keeping track of their rank, similar to the adjacencies/incidences earlier.
[52]:
features = {"rank_0": x_0, "rank_1": x_1, "rank_2": x_2, "rank_3": x_3}
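The `Network` defined later passes a single `channels` value to `SCCN`, which suggests that the features of every rank are expected to share the same width (here 2). A small sanity check on the `features` dictionary could look like the sketch below. `check_feature_widths` is a hypothetical helper, and NumPy zero arrays with the cell counts printed above stand in for the real tensors.

```python
import numpy as np


def check_feature_widths(features):
    """Return the shared feature width, or raise if ranks disagree."""
    widths = {rank: x.shape[1] for rank, x in features.items()}
    if len(set(widths.values())) != 1:
        raise ValueError(f"Feature widths differ across ranks: {widths}")
    return next(iter(widths.values()))


# Stand-ins with the cell counts printed above and feat_dim = 2.
toy_features = {
    f"rank_{r}": np.zeros((n, 2)) for r, n in enumerate([34, 78, 45, 11])
}
print(check_feature_widths(toy_features))  # 2
```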
Define binary labels
We retrieve a label for each node. The Karate Club network splits into two social groups, so we assign each node a binary label indicating which group it belongs to.
We hold out the true labels of the last four nodes for testing and train on the first 30.
[53]:
# Group membership (1 or 0) of each of the 34 nodes.
y = np.array(
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1,
     1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
)
y_true = np.zeros((34, 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y
y_train = y_true[:30]
y_test = y_true[-4:]
y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)
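The two-column label matrix is a one-hot encoding in which column 0 marks group 1 and column 1 marks group 0. An equivalent, more compact construction indexes rows of an identity matrix. The NumPy sketch below, on a shortened hypothetical label vector, checks that it matches the column-by-column version used above.

```python
import numpy as np

y = np.array([1, 1, 0, 1, 0, 0])  # hypothetical labels for six nodes

# Column-by-column version, as in the cell above.
y_cols = np.zeros((len(y), 2))
y_cols[:, 0] = y
y_cols[:, 1] = 1 - y

# Identity-indexing version: label 1 -> row [1, 0], label 0 -> row [0, 1].
y_eye = np.eye(2)[1 - y]

# Disjoint train/test split that covers every node.
n_test = 2
y_train, y_test = y_eye[:-n_test], y_eye[-n_test:]
print(np.array_equal(y_cols, y_eye), y_train.shape, y_test.shape)
```

In the notebook, `y_true[:30]` and `y_true[-4:]` are likewise disjoint and together cover all 34 nodes, since 30 + 4 = 34.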
Create the Neural Network
Using the SCCN class, we create our neural network with stacked layers. Given the considered dataset and task (Karate Club, node classification), a final linear layer maps the rank-0 (node) features to an output of shape \(n_\text{nodes} \times 2\), followed by a softmax, so we can compare it with our binary labels.
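The readout is a per-node linear map followed by a softmax over the two classes. The NumPy sketch below implements only this head on random stand-in node embeddings; the weights `W` and `b` are illustrative. It shows that each row of the output is a probability vector.

```python
import numpy as np


def softmax(z, axis=1):
    # Subtract the row maximum for numerical stability.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
n_nodes, channels, out_channels = 34, 2, 2
h0 = rng.normal(size=(n_nodes, channels))      # stand-in for features["rank_0"]
# torch.nn.Linear stores its weight as (out, in); here W is (in, out).
W = rng.normal(size=(channels, out_channels))
b = np.zeros(out_channels)

probs = softmax(h0 @ W + b)
print(probs.shape, np.allclose(probs.sum(axis=1), 1.0))  # (34, 2) True
```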
[54]:
class Network(torch.nn.Module):
    def __init__(
        self, channels, out_channels, max_rank, n_layers=2, update_func="sigmoid"
    ):
        super().__init__()
        self.base_model = SCCN(
            channels=channels,
            max_rank=max_rank,
            n_layers=n_layers,
            update_func=update_func,
        )
        self.linear = torch.nn.Linear(channels, out_channels)
    def forward(self, features, incidences, adjacencies):
        features = self.base_model(features, incidences, adjacencies)
        # Node-level readout: map rank-0 features to class scores, then
        # normalize them into per-node class probabilities.
        x = self.linear(features["rank_0"])
        return torch.softmax(x, dim=1)
[55]:
n_layers = 2
out_channels = 2
model = Network(
    channels=channels_nodes,
    out_channels=out_channels,
    max_rank=max_rank,
    n_layers=n_layers,
    update_func="sigmoid",
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
Train the Neural Network
The following cell trains the network for 100 epochs, reporting training accuracy every epoch and test accuracy every 10 epochs. Test accuracy varies more between runs than training accuracy, likely because the dataset is small.
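The accuracy in the training cell thresholds each class probability at 0.5 and counts a node as correct only if both entries of its prediction match the one-hot label. With only four test nodes, test accuracy can only take the values 0, 0.25, 0.5, 0.75 and 1, so a single misclassified node moves it by 0.25. A NumPy sketch of the metric on hypothetical probabilities (`onehot_accuracy` is our own helper):

```python
import numpy as np


def onehot_accuracy(probs, y_onehot, threshold=0.5):
    """Fraction of rows whose thresholded prediction equals the one-hot label."""
    y_pred = (probs > threshold).astype(int)
    return float((y_pred == y_onehot).all(axis=1).mean())


# Hypothetical predictions for four test nodes; the third one is wrong.
probs = np.array([[0.8, 0.2], [0.3, 0.7], [0.6, 0.4], [0.1, 0.9]])
y_test = np.array([[1, 0], [0, 1], [0, 1], [0, 1]])
print(onehot_accuracy(probs, y_test))  # 0.75
```

A prediction of exactly [0.5, 0.5] thresholds to [0, 0] and therefore counts as wrong for either label.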
[57]:
test_interval = 10
num_epochs = 100
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()
    y_hat = model(features, incidences, adjacencies)
    # Note: y_hat holds softmax probabilities, and the with-logits loss
    # applies a further sigmoid to them.
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        y_hat[: len(y_train)].float(), y_train.float()
    )
    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()
    # A node counts as correct only if both thresholded entries match its label.
    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred[: len(y_train)] == y_train).all(dim=1).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            y_hat_test = model(features, incidences, adjacencies)
            y_pred_test = torch.where(
                y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)
            )
            test_accuracy = (
                torch.eq(y_pred_test[-len(y_test) :], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)
Epoch: 1 loss: 0.7267 Train_acc: 0.4333
Epoch: 2 loss: 0.7176 Train_acc: 0.7000
Epoch: 3 loss: 0.7089 Train_acc: 0.5000
Epoch: 4 loss: 0.7013 Train_acc: 0.5667
Epoch: 5 loss: 0.6953 Train_acc: 0.5667
Epoch: 6 loss: 0.6906 Train_acc: 0.5667
Epoch: 7 loss: 0.6867 Train_acc: 0.5667
Epoch: 8 loss: 0.6829 Train_acc: 0.5667
Epoch: 9 loss: 0.6789 Train_acc: 0.5667
Epoch: 10 loss: 0.6745 Train_acc: 0.5667
Test_acc: 0.0000
Epoch: 11 loss: 0.6696 Train_acc: 0.5667
Epoch: 12 loss: 0.6645 Train_acc: 0.5667
Epoch: 13 loss: 0.6592 Train_acc: 0.5667
Epoch: 14 loss: 0.6538 Train_acc: 0.4667
Epoch: 15 loss: 0.6486 Train_acc: 0.8000
Epoch: 16 loss: 0.6437 Train_acc: 0.8000
Epoch: 17 loss: 0.6392 Train_acc: 0.8000
Epoch: 18 loss: 0.6351 Train_acc: 0.8000
Epoch: 19 loss: 0.6316 Train_acc: 0.8000
Epoch: 20 loss: 0.6286 Train_acc: 0.7667
Test_acc: 0.7500
Epoch: 21 loss: 0.6260 Train_acc: 0.7667
Epoch: 22 loss: 0.6238 Train_acc: 0.7667
Epoch: 23 loss: 0.6220 Train_acc: 0.7667
Epoch: 24 loss: 0.6203 Train_acc: 0.7667
Epoch: 25 loss: 0.6188 Train_acc: 0.7667
Epoch: 26 loss: 0.6173 Train_acc: 0.7667
Epoch: 27 loss: 0.6158 Train_acc: 0.7667
Epoch: 28 loss: 0.6142 Train_acc: 0.7667
Epoch: 29 loss: 0.6125 Train_acc: 0.7667
Epoch: 30 loss: 0.6106 Train_acc: 0.8000
Test_acc: 0.7500
Epoch: 31 loss: 0.6084 Train_acc: 0.8000
Epoch: 32 loss: 0.6061 Train_acc: 0.8000
Epoch: 33 loss: 0.6035 Train_acc: 0.8333
Epoch: 34 loss: 0.6007 Train_acc: 0.8333
Epoch: 35 loss: 0.5975 Train_acc: 0.8333
Epoch: 36 loss: 0.5942 Train_acc: 0.8333
Epoch: 37 loss: 0.5905 Train_acc: 0.8667
Epoch: 38 loss: 0.5860 Train_acc: 0.8667
Epoch: 39 loss: 0.5706 Train_acc: 0.8667
Epoch: 40 loss: 0.5554 Train_acc: 0.9333
Test_acc: 0.7500
Epoch: 41 loss: 0.5529 Train_acc: 0.9333
Epoch: 42 loss: 0.5500 Train_acc: 0.9333
Epoch: 43 loss: 0.5448 Train_acc: 0.9667
Epoch: 44 loss: 0.5443 Train_acc: 0.9667
Epoch: 45 loss: 0.5422 Train_acc: 0.9667
Epoch: 46 loss: 0.5390 Train_acc: 0.9667
Epoch: 47 loss: 0.5368 Train_acc: 0.9667
Epoch: 48 loss: 0.5358 Train_acc: 0.9667
Epoch: 49 loss: 0.5351 Train_acc: 0.9667
Epoch: 50 loss: 0.5342 Train_acc: 0.9667
Test_acc: 0.7500
Epoch: 51 loss: 0.5329 Train_acc: 0.9667
Epoch: 52 loss: 0.5316 Train_acc: 0.9667
Epoch: 53 loss: 0.5305 Train_acc: 0.9667
Epoch: 54 loss: 0.5298 Train_acc: 0.9667
Epoch: 55 loss: 0.5292 Train_acc: 0.9667
Epoch: 56 loss: 0.5287 Train_acc: 0.9667
Epoch: 57 loss: 0.5283 Train_acc: 0.9667
Epoch: 58 loss: 0.5278 Train_acc: 0.9667
Epoch: 59 loss: 0.5274 Train_acc: 0.9667
Epoch: 60 loss: 0.5269 Train_acc: 0.9667
Test_acc: 0.5000
Epoch: 61 loss: 0.5264 Train_acc: 0.9667
Epoch: 62 loss: 0.5260 Train_acc: 0.9667
Epoch: 63 loss: 0.5256 Train_acc: 0.9667
Epoch: 64 loss: 0.5252 Train_acc: 0.9667
Epoch: 65 loss: 0.5249 Train_acc: 0.9667
Epoch: 66 loss: 0.5246 Train_acc: 0.9667
Epoch: 67 loss: 0.5244 Train_acc: 0.9667
Epoch: 68 loss: 0.5242 Train_acc: 0.9667
Epoch: 69 loss: 0.5240 Train_acc: 0.9667
Epoch: 70 loss: 0.5238 Train_acc: 0.9667
Test_acc: 0.7500
Epoch: 71 loss: 0.5236 Train_acc: 0.9667
Epoch: 72 loss: 0.5235 Train_acc: 0.9667
Epoch: 73 loss: 0.5234 Train_acc: 0.9667
Epoch: 74 loss: 0.5232 Train_acc: 0.9667
Epoch: 75 loss: 0.5231 Train_acc: 0.9667
Epoch: 76 loss: 0.5230 Train_acc: 0.9667
Epoch: 77 loss: 0.5229 Train_acc: 0.9667
Epoch: 78 loss: 0.5228 Train_acc: 0.9667
Epoch: 79 loss: 0.5227 Train_acc: 0.9667
Epoch: 80 loss: 0.5226 Train_acc: 0.9667
Test_acc: 0.7500
Epoch: 81 loss: 0.5225 Train_acc: 0.9667
Epoch: 82 loss: 0.5225 Train_acc: 0.9667
Epoch: 83 loss: 0.5224 Train_acc: 0.9667
Epoch: 84 loss: 0.5223 Train_acc: 0.9667
Epoch: 85 loss: 0.5222 Train_acc: 0.9667
Epoch: 86 loss: 0.5222 Train_acc: 0.9667
Epoch: 87 loss: 0.5221 Train_acc: 0.9667
Epoch: 88 loss: 0.5221 Train_acc: 0.9667
Epoch: 89 loss: 0.5220 Train_acc: 0.9667
Epoch: 90 loss: 0.5220 Train_acc: 0.9667
Test_acc: 0.7500
Epoch: 91 loss: 0.5219 Train_acc: 0.9667
Epoch: 92 loss: 0.5219 Train_acc: 0.9667
Epoch: 93 loss: 0.5218 Train_acc: 0.9667
Epoch: 94 loss: 0.5218 Train_acc: 0.9667
Epoch: 95 loss: 0.5218 Train_acc: 0.9667
Epoch: 96 loss: 0.5217 Train_acc: 0.9667
Epoch: 97 loss: 0.5217 Train_acc: 0.9667
Epoch: 98 loss: 0.5216 Train_acc: 0.9667
Epoch: 99 loss: 0.5216 Train_acc: 0.9667
Epoch: 100 loss: 0.5216 Train_acc: 0.9667
Test_acc: 0.7500
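One detail of the training cell is worth noting: `Network.forward` returns softmax probabilities, while `binary_cross_entropy_with_logits` expects raw logits and applies a sigmoid itself, so the loss is effectively computed on sigmoid(softmax(x)). Training still works, as the log shows, but the loss cannot fall below about 0.503 (reached only for perfectly confident predictions), which is consistent with the plateau near 0.52 above. The NumPy sketch below reimplements both losses (our own code, not PyTorch's) to check that consistent pairings agree and to compute that floor.

```python
import numpy as np


def bce(p, t):
    """Binary cross-entropy on probabilities p in (0, 1), averaged."""
    return float(np.mean(-(t * np.log(p) + (1 - t) * np.log(1 - p))))


def bce_with_logits(z, t):
    """Binary cross-entropy on raw logits z, averaged (stable form)."""
    # log(1 + exp(-z)) for t = 1 and log(1 + exp(z)) for t = 0, via logaddexp.
    return float(np.mean(t * np.logaddexp(0, -z) + (1 - t) * np.logaddexp(0, z)))


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


t = np.array([[1.0, 0.0]])   # one-hot target for a single node
z = np.array([[3.0, -1.0]])  # hypothetical raw logits

# Consistent pairings agree: with-logits on z equals plain BCE on sigmoid(z).
print(np.isclose(bce_with_logits(z, t), bce(sigmoid(z), t)))  # True

# Softmax probabilities fed to the with-logits loss, as in the cell above:
# even a perfectly confident prediction [1, 0] leaves a loss of about 0.503.
print(round(bce_with_logits(np.array([[1.0, 0.0]]), t), 4))  # 0.5032
```

To make the pairing consistent, one could either drop `torch.softmax` from `Network.forward` and keep `binary_cross_entropy_with_logits` (applying the softmax only when computing accuracy), or keep the softmax and use `torch.nn.functional.binary_cross_entropy` instead.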