Train a Simplicial Attention Network (SAN)#

We create and train a Simplicial Attention Network (SAN), as originally proposed in Giusti, Battiloro et al.: Simplicial Attention Neural Networks (2022). The aim of this notebook is to be didactic and clear; for further technical and implementation details, please refer to the original paper and the TopoModelX documentation.

Abstract#

The aim of this work is to introduce simplicial attention networks (SANs), i.e., novel neural architectures that operate on data defined on simplicial complexes leveraging masked self-attentional layers. Hinging on formal arguments from topological signal processing, we introduce a proper self-attention mechanism able to process data components at different layers (e.g., nodes, edges, triangles, and so on), while learning how to weight both upper and lower neighborhoods of the given topological domain in a totally task-oriented fashion. The proposed SANs generalize most of the current architectures available for processing data defined on simplicial complexes.

SAN-architecture

Remark. The notation we use is defined in Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023) and Hajij et al.: Topological Deep Learning: Going Beyond Graph Data (2023). Custom symbols are introduced throughout the notebook when necessary.

The Neural Network#

The SAN layer takes rank-\(r\) signals as input and gives rank-\(r\) signals as output. The involved neighborhoods are:

\begin{equation} \mathcal N = \{\mathcal N_1, \mathcal N_2,...,\mathcal N_{2p+1}\} = \{A_{\uparrow, r}, A_{\downarrow, r}, A_{\uparrow, r}^2, A_{\downarrow, r}^2,...,A_{\uparrow, r}^p, A_{\downarrow, r}^p, Q_r\}, \end{equation} where \(Q_r\) is a sparse projection operator (weighted matrix) onto the kernel of the \(r\)-th Hodge Laplacian \(L_r\), computed as in the original paper. \(Q_r\) has the same sparsity pattern (support) as \(L_r\).
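To make \(Q_r\) less abstract, here is a minimal dense sketch of a projector onto the kernel of the Hodge Laplacian, obtained from its eigendecomposition. The helper name `harmonic_projector` is ours, and this is not the sparse, learnable construction of the original paper; it only illustrates what \(Q_r\) approximates.

import torch


def harmonic_projector(laplacian: torch.Tensor, tol: float = 1e-6) -> torch.Tensor:
    """Illustrative dense projector onto ker(L_r).

    The original paper builds a sparse approximation of this operator; here we
    simply collect the eigenvectors associated with the (numerically) zero
    eigenvalues of the Hodge Laplacian, which is symmetric positive
    semi-definite.
    """
    eigenvalues, eigenvectors = torch.linalg.eigh(laplacian)
    harmonic_basis = eigenvectors[:, eigenvalues < tol]  # kernel eigenvectors
    return harmonic_basis @ harmonic_basis.T  # orthogonal projector onto ker(L_r)

For rank 1, the full Hodge Laplacian is \(L_1 = L_{\downarrow,1} + L_{\uparrow,1}\), so the exact harmonic projector would be `harmonic_projector(laplacian_down.to_dense() + laplacian_up.to_dense())` once those matrices are built later in the notebook.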

The equation of the SAN layer of this neural network is given by:

\begin{equation} \textbf{h}_x^{t+1} = \phi^t \Bigg ( \textbf{h}_x^{t}, \bigotimes_{\mathcal{N}_k\in\mathcal N}\bigoplus_{y \in \mathcal{N}_k(x)} \widetilde{\alpha}_k(h_x^t,h_y^t)\Bigg ), \end{equation}

with \(\widetilde{\alpha}_k\) being either an attention function \(\alpha_k\) if \(\mathcal{N}_k \neq Q_r\), or a standard convolution term (affine transformation with weights given by the entries of \(Q_r\)) if \(\mathcal{N}_k = Q_r\).

Therefore, the SAN layer consists of an attentional convolution from rank-\(r\) cells to rank-\(r\) cells, using an adjacency-based message-passing scheme over neighborhoods of up to \(p\) hops:

\begin{align*}
&🟥\textrm{ Message.} &\quad m_{(y \rightarrow x),k} =& \alpha_k(h_x^t,h_y^t) = a_k(h_x^{t}, h_y^{t}) \cdot \psi_k^t(h_x^{t})\quad \forall \mathcal N_k \in \mathcal{N}\\
&🟧 \textrm{ Within-Neighborhood Aggregation.} &\quad m_{x,k} =& \bigoplus_{y \in \mathcal{N}_k(x)} m_{(y \rightarrow x),k}\\
&🟩 \textrm{ Between-Neighborhood Aggregation.} &\quad m_{x} =& \bigotimes_{\mathcal{N}_k\in\mathcal N}m_{x,k}\\
&🟦 \textrm{ Update.} &\quad h_x^{t+1} =& \phi^{t}(h_x^t, m_{x})
\end{align*}
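For orientation only, here is a schematic sketch of how these four steps compose inside one layer, assuming sums for both the within- and between-neighborhood aggregations; the function names and signatures are ours and do not reflect the TopoModelX implementation used later in this notebook.

import torch


def san_layer_sketch(h, neighborhood_matrices, message_fns, update_fn):
    """Schematic composition of the four steps above (sum aggregations assumed).

    h:                     (n_cells, channels) rank-r input signals
    neighborhood_matrices: one (n_cells, n_cells) matrix per neighborhood N_k
    message_fns:           one callable per N_k returning the (n_cells, n_cells,
                           channels) tensor of messages m_{(y -> x), k}
    update_fn:             the update function phi
    """
    per_neighborhood = []
    for n_k, message_fn in zip(neighborhood_matrices, message_fns):
        messages = message_fn(h, n_k)                      # 🟥 messages
        mask = (n_k != 0).unsqueeze(-1)                    # keep only y in N_k(x)
        per_neighborhood.append((messages * mask).sum(1))  # 🟧 within-neighborhood
    m = torch.stack(per_neighborhood).sum(0)               # 🟩 between-neighborhood
    return update_fn(h, m)                                 # 🟦 update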

The Task:#

We train this model to perform a binary node classification task on the Karate Club dataset. We use a “GAT-like” attention function, in which two different sets of attention weights \(a_\uparrow\) and \(a_\downarrow\) are learned for the upper neighborhoods \(A_{\uparrow,1}^p\) and the lower neighborhoods \(A_{\downarrow,1}^p\) (\(p=1,...,P\)), respectively, i.e.:

  • If \(\mathcal{N}_k \neq Q_r\) and suppose, as an example, \(\mathcal{N}_k = A_{\downarrow,1}^g\), the \(g\)-hops lower neighborhood: \begin{align} &a_k(h_x^{t}, h_y^{t}) = (\textrm{softmax}_j(\textrm{LeakyReLU}(a_{\downarrow}^T[\underset{p=1}{\overset{P}{||}}h_x^{t}W_{\downarrow,p}|| \underset{p=1}{\overset{P}{||}}h_y^{t}W_{\downarrow,p}]))^g\\ & \psi_k^t(h_x^{t}) = h_x^{t}W_{\downarrow,g}. \end{align}

  • If \(\mathcal{N}_k = Q_r\): \begin{align} &a_k(h_x^{t}, h_y^{t}) = Q_{x,y}\\ & \psi_k^t(h_x^{t}) = h_x^{t}W. \end{align}

\(W\), \(a_\downarrow\), \(a_\uparrow\), \(\{W_{\downarrow,p}\}_{p=1}^P\) and \(\{W_{\uparrow,p}\}_{p=1}^P\) are learnable weights.
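As a concrete (dense) illustration of this attention function, the sketch below computes the masked coefficients for a single neighborhood; the name, signature, and dense masking are our own assumptions, not the TopoModelX internals.

import torch
import torch.nn.functional as F


def gat_like_attention(h, neighborhood, weight, att_vec):
    """Sketch of masked attention coefficients for one neighborhood.

    h:            (n_cells, in_channels) rank-r input signals
    neighborhood: (n_cells, n_cells) binary adjacency (support of a Laplacian power)
    weight:       (in_channels, out_channels) learnable matrix W
    att_vec:      (2 * out_channels,) learnable attention vector a
    """
    z = h @ weight  # affine transformation psi
    out_channels = z.shape[1]
    # e[x, y] = LeakyReLU(a^T [z_x || z_y]), splitting a into its two halves
    e = (z @ att_vec[:out_channels]).unsqueeze(1) + (z @ att_vec[out_channels:]).unsqueeze(0)
    e = F.leaky_relu(e)
    # Mask out non-neighbors, then normalize over each neighborhood (softmax over y)
    e = e.masked_fill(neighborhood == 0, float("-inf"))
    alpha = torch.nan_to_num(torch.softmax(e, dim=1))  # empty neighborhoods give NaN rows
    return alpha, z

The within-neighborhood aggregation then amounts to a weighted sum of the transformed signals (e.g. `alpha @ z`); the \(p\)-hop powers, the harmonic term, and the concatenation over \(p\) in the equations above are left to the TopoModelX SAN layers used below.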

[1]:
import numpy as np
import toponetx as tnx
import torch

from topomodelx.nn.simplicial.san import SAN
from topomodelx.utils.sparse import from_sparse

%load_ext autoreload
%autoreload 2
[2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cpu

Pre-processing#

The first step is to import the Karate Club (https://www.jstor.org/stable/3629752) dataset. This is a single graph with 34 nodes belonging to two different social groups. We will use these groups for the node-level binary classification task.

We must first lift our graph dataset into the simplicial complex domain.
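The toponetx loader used in the next cell already returns the lifted complex. For a generic graph, one common lifting is the clique complex, in which every \((k+1)\)-clique becomes a \(k\)-simplex; a small illustrative sketch (not needed for the rest of the notebook) is:

import networkx as nx
import toponetx as tnx

# Clique lifting: every (k+1)-clique of the graph becomes a k-simplex.
graph = nx.karate_club_graph()
clique_complex = tnx.SimplicialComplex(list(nx.enumerate_all_cliques(graph)))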

[3]:
dataset = tnx.datasets.karate_club(complex_type="simplicial")
print(dataset)
Simplicial Complex with shape (34, 78, 45, 11, 2) and dimension 4
[4]:
dataset.shape
[4]:
(34, 78, 45, 11, 2)

We now retrieve the neighborhoods (i.e. their representative matrices) that we will use to send messages on the domain. In this case, we decide w.l.o.g. to work at the edge level (thus considering a simplicial complex of order 2). We therefore need the lower and upper Laplacians of rank 1, \(L_{\downarrow,1}=B_1^TB_1\) and \(L_{\uparrow,1}=B_2B_2^T\), both with dimensions \(n_\text{edges} \times n_\text{edges}\), where \(B_1\) and \(B_2\) are the incidence matrices of rank 1 and 2. Please notice that the binary adjacencies \(A_{\downarrow,1}^p\) and \(A_{\uparrow,1}^p\) encoding the \(p\)-hop neighborhoods are given by the support (the non-zero pattern) of \(L_{\downarrow,1}^p\) and \(L_{\uparrow,1}^p\), respectively; see the small sketch after the next cell. We also convert the neighborhood structures to torch tensors.

Remark. In the case of rank-0 simplices (nodes), there is no lower Laplacian; in this case, we just initialize the lower Laplacian as a zero matrix, and SAN automatically becomes a GAT-like architecture. In the case of simplices of maximum rank (the order of the complex), there is no upper Laplacian; in this case we can also initialize it as a zero matrix, and SAN will only consider the lower adjacencies.

[5]:
simplex_order_k = 1
# Down laplacian
try:
    laplacian_down = from_sparse(dataset.down_laplacian_matrix(rank=simplex_order_k))
except ValueError:
    laplacian_down = torch.zeros(
        (dataset.shape[simplex_order_k], dataset.shape[simplex_order_k])
    ).to_sparse()
# Up laplacian
try:
    laplacian_up = from_sparse(dataset.up_laplacian_matrix(rank=simplex_order_k))
except ValueError:
    laplacian_up = torch.zeros(
        (dataset.shape[simplex_order_k], dataset.shape[simplex_order_k])
    ).to_sparse()
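As mentioned above, the binary \(p\)-hop adjacencies are the supports of the Laplacian powers. The SAN layers only take the two Laplacians as input and handle the multi-hop terms internally, so the following is just an illustrative (dense) sketch of that relation; the helper name is ours.

def p_hop_adjacency(laplacian: torch.Tensor, p: int) -> torch.Tensor:
    """Binary p-hop adjacency: the support (non-zero pattern) of L^p.

    Dense computation for illustration only; this helper is not used below.
    """
    dense = laplacian.to_dense() if laplacian.is_sparse else laplacian
    power = torch.linalg.matrix_power(dense, p)
    return (power.abs() > 1e-9).to(dense.dtype)


# For example, the 2-hop upper adjacency of the edges:
# adjacency_up_2 = p_hop_adjacency(laplacian_up, 2)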

We define the edge features to be the gradient of the node features, i.e. given the node feature matrix \(X_0\), we compute the edge feature matrix as \(X_1 = B_1^TX_0\). We will finally obtain the estimated node labels from the updated edge features by multiplying them again by \(B_1\), i.e. the final node features are computed as the divergence of the final edge features.

Remark. Please notice that this way of deriving edge features from node features (and node features from edge features) can also be seen as a (non-learnable) message passing between rank-0 cells (nodes) and rank-1 cells (edges).

[6]:
x_0 = list(dataset.get_simplex_attributes("node_feat").values())
x_0 = torch.tensor(np.stack(x_0))
channels_nodes = x_0.shape[-1]
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")

x_1 = list(dataset.get_simplex_attributes("edge_feat").values())
x_1 = torch.tensor(np.stack(x_1))
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")

x_2 = list(dataset.get_simplex_attributes("face_feat").values())
x_2 = torch.tensor(np.stack(x_2))
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")
There are 34 nodes with features of dimension 2.
There are 78 edges with features of dimension 2.
There are 45 faces with features of dimension 2.

We use the incidence matrix between nodes and edges:

[7]:
incidence_0_1 = from_sparse(dataset.incidence_matrix(1))

The final edge features are obtained by summing the original edge features and the projection of the node features onto the edges (via the incidence matrix):

[8]:
x = x_1 + torch.sparse.mm(incidence_0_1.T, x_0)

Hence, the final input features are given by this sum. We also pre-define the number of hidden and output channels of the model.

[9]:
in_channels = x.shape[-1]
hidden_channels = 16
out_channels = 2

We retrieve the labels associated with the nodes. In the Karate Club dataset, two social groups emerge, so we assign binary labels to the nodes indicating which group they belong to.

We one-hot encode the binary labels, and keep the last four nodes for testing.

[10]:
y = np.array(
    [
        1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1,
        1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    ]
)
y_true = np.zeros((34, 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y
y_train = y_true[:30]
y_test = y_true[-4:]

y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)

Create the Neural Network#

Using the SAN class, we create our neural network with stacked layers. A final linear layer maps the edge embeddings to 2 output channels; once the edge-level output is projected back to the nodes (see the training loop below), we obtain node-level predictions of shape \(n_\text{nodes} \times 2\) that we can compare with our binary labels.

[11]:
class Network(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, n_layers=1):
        super().__init__()
        self.base_model = SAN(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            n_layers=n_layers,
        )
        self.linear = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, laplacian_up, laplacian_down):
        # Edge-level embeddings from the stacked SAN layers.
        x = self.base_model(x, laplacian_up, laplacian_down)
        # Linear read-out to the number of output channels (here, 2 classes).
        x = self.linear(x)
        return torch.sigmoid(x)
[12]:
laplacian_up.shape, laplacian_down.shape, x_1.shape
[12]:
(torch.Size([78, 78]), torch.Size([78, 78]), torch.Size([78, 2]))
[13]:
n_layers = 1
model = Network(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    n_layers=n_layers,
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model
[13]:
Network(
  (base_model): SAN(
    (layers): ModuleList(
      (0): SANLayer(
        (conv_down): SANConv()
        (conv_up): SANConv()
        (conv_harmonic): Conv()
      )
    )
  )
  (linear): Linear(in_features=16, out_features=2, bias=True)
)

Train the Neural Network#

The following cell performs the training, looping over the network for a small number of epochs.

[14]:
test_interval = 10
num_epochs = 50
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()

    y_hat_edge = model(x, laplacian_up=laplacian_up, laplacian_down=laplacian_down)
    # We project the edge-level output of the model to the node-level
    # and apply softmax fn to get the final node-level classification output
    y_hat = torch.softmax(torch.sparse.mm(incidence_0_1, y_hat_edge), dim=1)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        y_hat[: len(y_train)].float(), y_train.float()
    )
    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred[: len(y_train)] == y_train).all(dim=1).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            y_hat_edge_test = model(
                x, laplacian_up=laplacian_up, laplacian_down=laplacian_down
            )
            # Projection to node-level
            y_hat_test = torch.softmax(
                torch.sparse.mm(incidence_0_1, y_hat_edge_test), dim=1
            )
            y_pred_test = torch.where(
                y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)
            )
            test_accuracy = (
                torch.eq(y_pred_test[-len(y_test) :], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)
Epoch: 1 loss: 0.7247 Train_acc: 0.3000
Epoch: 2 loss: 0.7226 Train_acc: 0.6667
Epoch: 3 loss: 0.7203 Train_acc: 0.6667
Epoch: 4 loss: 0.7167 Train_acc: 0.6667
Epoch: 5 loss: 0.7115 Train_acc: 0.6667
Epoch: 6 loss: 0.7057 Train_acc: 0.7000
Epoch: 7 loss: 0.6999 Train_acc: 0.7000
Epoch: 8 loss: 0.6931 Train_acc: 0.7000
Epoch: 9 loss: 0.6874 Train_acc: 0.7000
Epoch: 10 loss: 0.6814 Train_acc: 0.7000
Test_acc: 0.2500
Epoch: 11 loss: 0.6760 Train_acc: 0.7333
Epoch: 12 loss: 0.6716 Train_acc: 0.7333
Epoch: 13 loss: 0.6669 Train_acc: 0.7333
Epoch: 14 loss: 0.6624 Train_acc: 0.7000
Epoch: 15 loss: 0.6586 Train_acc: 0.7000
Epoch: 16 loss: 0.6551 Train_acc: 0.7667
Epoch: 17 loss: 0.6528 Train_acc: 0.7667
Epoch: 18 loss: 0.6508 Train_acc: 0.7667
Epoch: 19 loss: 0.6491 Train_acc: 0.7667
Epoch: 20 loss: 0.6479 Train_acc: 0.7667
Test_acc: 0.2500
Epoch: 21 loss: 0.6463 Train_acc: 0.7667
Epoch: 22 loss: 0.6452 Train_acc: 0.7667
Epoch: 23 loss: 0.6438 Train_acc: 0.7333
Epoch: 24 loss: 0.6433 Train_acc: 0.7333
Epoch: 25 loss: 0.6419 Train_acc: 0.7667
Epoch: 26 loss: 0.6412 Train_acc: 0.7667
Epoch: 27 loss: 0.6397 Train_acc: 0.7667
Epoch: 28 loss: 0.6391 Train_acc: 0.7667
Epoch: 29 loss: 0.6379 Train_acc: 0.7667
Epoch: 30 loss: 0.6369 Train_acc: 0.7667
Test_acc: 0.2500
Epoch: 31 loss: 0.6360 Train_acc: 0.7667
Epoch: 32 loss: 0.6347 Train_acc: 0.7667
Epoch: 33 loss: 0.6333 Train_acc: 0.7667
Epoch: 34 loss: 0.6317 Train_acc: 0.7667
Epoch: 35 loss: 0.6298 Train_acc: 0.7667
Epoch: 36 loss: 0.6282 Train_acc: 0.7667
Epoch: 37 loss: 0.6272 Train_acc: 0.7667
Epoch: 38 loss: 0.6267 Train_acc: 0.8000
Epoch: 39 loss: 0.6265 Train_acc: 0.8000
Epoch: 40 loss: 0.6262 Train_acc: 0.8000
Test_acc: 0.2500
Epoch: 41 loss: 0.6260 Train_acc: 0.8000
Epoch: 42 loss: 0.6259 Train_acc: 0.8000
Epoch: 43 loss: 0.6260 Train_acc: 0.8000
Epoch: 44 loss: 0.6261 Train_acc: 0.7667
Epoch: 45 loss: 0.6260 Train_acc: 0.7667
Epoch: 46 loss: 0.6258 Train_acc: 0.8000
Epoch: 47 loss: 0.6255 Train_acc: 0.8000
Epoch: 48 loss: 0.6252 Train_acc: 0.8000
Epoch: 49 loss: 0.6249 Train_acc: 0.8000
Epoch: 50 loss: 0.6245 Train_acc: 0.8000
Test_acc: 0.2500
[ ]: