Train a Cell Attention Network (CAN)

We create and train a Cell Attention Network (CAN), originally proposed in Giusti et al., Cell Attention Networks (2022). The aim of this notebook is to be didactic and clear; for further technical and implementation details, please refer to the original paper and the TopoModelX documentation.

Abstract:

Since their introduction, graph attention networks achieved outstanding results in graph representation learning tasks. However, these networks consider only pairwise relationships among nodes and then they are not able to fully exploit higher-order interactions present in many real world data-sets. In this paper, we introduce Cell Attention Networks (CANs), a neural architecture operating on data defined over the vertices of a graph, representing the graph as the 1-skeleton of a cell complex introduced to capture higher order interactions. In particular, we exploit the lower and upper neighborhoods, as encoded in the cell complex, to design two independent masked self-attention mechanisms, thus generalizing the conventional graph attention strategy. The approach used in CANs is hierarchical and it incorporates the following steps: i) a lifting algorithm that learns edge features from node features; ii) a cell attention mechanism to find the optimal combination of edge features over both lower and upper neighbors; iii) a hierarchical edge pooling mechanism to extract a compact meaningful set of features.


Remark. The notation we use is defined in Papillon et al., Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023), and Hajij et al., Topological Deep Learning: Going Beyond Graph Data (2023). Custom symbols are introduced throughout the notebook when necessary.

The Neural Network:

In the original paper, the CAN layer takes rank-\(0\) signals as input and returns rank-\(0\) signals as output (more generally, it could take rank-\(r\) signals as input and return rank-\(r\) signals). The neighborhoods involved are \(\mathcal N = \{\mathcal N_1, \mathcal N_2\} = \{A_{\uparrow,r+1}, A_{\downarrow, r+1}\}\).

A CAN layer consists of the following three message-passing stages:

  1. Attentional Lift (to compute \(r+1\)-signals from \(r\)-signals):

\begin{align*} &🟥\textrm{ Message.} &\quad m_{(y,z) \rightarrow x} =& \alpha(h_y^0,h_z^0) = \Theta \cdot (h_y^0||h_z^0)\\ \\ &🟦\textrm{ Update.} &\quad h_x^1 =& \phi(h_x^0, m_{(y,z) \rightarrow x}) \end{align*}

Where:

  • \(\alpha\) is a learnable function parameterized by \(\Theta \in \mathbb R^{2F_0 \times H}\). When the input signals live on the nodes, \(F_0\) is the number of node features and \(H\) is the number of heads, as defined in the original paper.

  • \(||\) is the concatenation operator.

  • \(\phi\) is a learnable function that updates the features of a cell.
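The attentional lift can be sketched in plain Python. This is a minimal illustration, not the TopoModelX lift layer: the function name `attentional_lift` is hypothetical, \(\Theta\) is applied as \(\Theta^\top (h_y^0 || h_z^0)\) so that each edge gets one score per head, and the update \(\phi\) is assumed to be an elementwise ReLU optionally followed by concatenation with the edge's own input features.

```python
def attentional_lift(h, edges, theta, edge_features=None):
    """Sketch of the attentional lift from node signals to edge signals.

    h:             list of node feature vectors, each of length F_0
    edges:         list of (y, z) node-index pairs, one per edge x
    theta:         2*F_0 x H weight matrix (list of rows), i.e. Theta
    edge_features: optional list of existing edge feature vectors
    """
    n_heads = len(theta[0])
    lifted = []
    for e, (y, z) in enumerate(edges):
        concat = h[y] + h[z]  # h_y || h_z (list concatenation), length 2*F_0
        # Message: Theta^T (h_y || h_z), one scalar per head
        message = [
            sum(c * theta[i][k] for i, c in enumerate(concat))
            for k in range(n_heads)
        ]
        # Update (assumed phi): elementwise ReLU ...
        out = [max(0.0, m) for m in message]
        # ... followed by concatenation with the edge's own features, if any
        if edge_features is not None:
            out = out + edge_features[e]
        lifted.append(out)
    return lifted


# Toy example: 3 nodes with F_0 = 2 features, 2 edges, H = 2 heads
h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edges = [(0, 1), (1, 2)]
theta = [[1.0, -1.0], [0.5, 0.0], [0.0, 1.0], [2.0, 0.0]]
print(attentional_lift(h, edges, theta))  # [[3.0, 0.0], [2.5, 1.0]]
print(attentional_lift(h, edges, theta, edge_features=[[9.0], [8.0]]))
```

Each lifted edge signal has one entry per head, plus any pre-existing edge features appended after them.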

  2. (\(\times L\)) Attentional message passing at rank \(r+1\). The general equation is:

\begin{align*} \textbf{h}_x^{t+1} = \phi^t \Bigg ( \textbf{h}_x^{t}, \bigotimes_{\mathcal{N}_k\in\mathcal N}\bigoplus_{y \in \mathcal{N}_k(x)} \alpha_k(h_x^t,h_y^t)\Bigg ) \end{align*}

In detail:

\begin{align*} &🟥\textrm{ Message.} &\quad m_{(y \rightarrow x),k} =& \alpha_k(h_x^t,h_y^t) = a_k^t(h_x^{t}, h_y^{t}) \cdot \psi_k^t(h_y^{t})\quad \forall \mathcal N_k \in \mathcal{N}\\ \\ &🟧 \textrm{ Within-Neighborhood Aggregation.} &\quad m_{x,k} =& \bigoplus_{y \in \mathcal{N}_k(x)} m_{(y \rightarrow x),k}\\ \\ &🟩 \textrm{ Between-Neighborhood Aggregation.} &\quad m_{x} =& \bigotimes_{\mathcal{N}_k\in\mathcal N}m_{x,k}\\ \\ &🟦 \textrm{ Update.}&\quad h_x^{t+1} =& \phi^{t}(h_x^t, m_{x}) \end{align*}

Where:

  • \(\psi_k^t\) is a learnable function (e.g., a linear map) that transforms the features of an \((r+1)\)-cell before they are weighted by the attention coefficient.

  • \(a_k^t: \mathbb R^{F^t}\times \mathbb R^{F^t} \to \mathbb R\), where \(F^t\) is the number of features at layer \(t\), are learnable functions that evaluate the reciprocal importance of two \((r+1)\)-cells that share a common \(r\)-cell (lower neighbors) or lie on the boundary of the same \((r+2)\)-cell (upper neighbors).

  • \(\phi^t\) is a learnable function that updates the features of a cell.
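One round of attentional message passing for a single \((r+1)\)-cell can be sketched as follows. This is a minimal sketch under stated assumptions, not the TopoModelX `CANLayer`: \(a_k\) is taken to be a GAT-style score \(\mathrm{softmax}_y\big(\mathrm{LeakyReLU}(\mathbf a_k^\top[W_k h_x || W_k h_y])\big)\), \(\psi_k\) is the linear map \(W_k\), both \(\bigoplus\) and \(\bigotimes\) are sums, and \(\phi\) is assumed to be a residual connection followed by ReLU. All helper names are hypothetical.

```python
import math


def leaky_relu(v, slope=0.2):
    return v if v > 0 else slope * v


def matvec(w, v):
    # Matrix (list of rows) times vector
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(h, x, nbrs, w, att):
    # GAT-style a_k(h_x, h_y), normalized over the neighbors y of x
    wx = matvec(w, h[x])
    scores = [leaky_relu(sum(a * v for a, v in zip(att, wx + matvec(w, h[y]))))
              for y in nbrs]
    return softmax(scores)


def can_update(h, x, neighborhoods, params):
    """Update cell x given neighbor lists such as {"down": [...], "up": [...]}.

    params[k] = (W_k, att_k); W_k is square so the residual sum has matching sizes.
    """
    m_x = [0.0] * len(h[x])
    for k, nbrs in neighborhoods.items():
        if not nbrs:  # empty neighborhood: no message
            continue
        w, att = params[k]
        for alpha, y in zip(attention_weights(h, x, nbrs, w, att), nbrs):
            # Message a_k(h_x, h_y) * psi_k(h_y), summed within and across neighborhoods
            msg = [alpha * v for v in matvec(w, h[y])]
            m_x = [a + b for a, b in zip(m_x, msg)]
    # Update phi (assumed): residual connection followed by ReLU
    return [max(0.0, hx + mx) for hx, mx in zip(h[x], m_x)]


identity = [[1.0, 0.0], [0.0, 1.0]]
params = {"down": (identity, [0.1, 0.2, 0.3, 0.4]), "up": (identity, [0.4, 0.3, 0.2, 0.1])}
h = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]
# Cell 0 has one lower neighbor (cell 1) and one upper neighbor (cell 2)
print(can_update(h, 0, {"down": [1], "up": [2]}, params))  # [4.5, 1.5]
```

With a single neighbor per neighborhood, the softmax weight is 1, so the message reduces to the sum of the transformed neighbor features.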

  3. Attentional Pooling (performed after each message-passing round of stage 2):

\begin{align*} &🟥\textrm{ Message.} &\quad m_{x} =& \gamma^t(h_x^t) = \tau^t (a^t\cdot h_x^t)\\ \\ &🟦\textrm{ Update.} &\quad h_x^{t+1} =& m_{x}\,h_x^t, \quad \forall x\in \mathcal C_{r+1}^{t+1} \end{align*}

Where:

  • \(\gamma^t\) is a learnable function that computes the attention coefficients (self-scores), as defined in the original paper.

  • \(\tau^t\) is a non-linear function and \(a^t\) are learnable parameters.

  • \(\mathcal C^{t+1}_{r+1}\) is the set of rank-\((r+1)\) cells of the coarsened cell complex, obtained by keeping the rank-\((r+1)\) cells with the top-\(K\) self-scores \(\gamma^t(h_x^t)\).
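The pooling stage keeps only the highest-scoring cells and rescales their features by their self-scores. A minimal sketch, assuming \(\tau^t = \tanh\) and a fixed keep ratio (both are assumptions, as is the hypothetical name `attentional_pooling`); a full implementation would also restrict the neighborhood matrices to the kept cells, which is omitted here.

```python
import math


def attentional_pooling(h, a, ratio=0.5):
    """Top-K pooling of cell features h (list of vectors) with parameters a."""
    # Self-scores gamma(h_x) = tau(a . h_x), with tau = tanh (assumed)
    scores = [math.tanh(sum(ai * hi for ai, hi in zip(a, hx))) for hx in h]
    # K = ceil(ratio * number of cells), keeping at least one cell
    k = max(1, math.ceil(ratio * len(h)))
    keep = sorted(range(len(h)), key=lambda i: scores[i], reverse=True)[:k]
    keep.sort()  # preserve the original cell order
    # Update: h_x <- gamma(h_x) * h_x for every kept cell x
    pooled = [[scores[i] * v for v in h[i]] for i in keep]
    return keep, pooled


h = [[1.0, 0.0], [0.0, 2.0], [-1.0, 0.0], [0.0, 0.0]]
keep, pooled = attentional_pooling(h, a=[1.0, 1.0], ratio=0.5)
print(keep)  # [0, 1]
```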

The Task:

We train this model to perform whole-complex classification on MUTAG from the TUDataset (https://paperswithcode.com/dataset/mutag). This dataset contains:

  • 188 samples of chemical compounds represented as graphs,

  • with 7 discrete node features.

The task is to predict the mutagenicity of each compound on Salmonella Typhimurium. We use a “GAT-like” attention function, following the approach of SAN; a “GATv2-like” attention function is also implemented.
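To make the “GAT-like” versus “GATv2-like” distinction concrete, the sketch below implements the two generic scoring functions, GAT: \(\mathrm{LeakyReLU}(\mathbf a^\top[W h_x || W h_y])\) and GATv2: \(\mathbf a^\top \mathrm{LeakyReLU}(W[h_x || h_y])\), on scalar toy features. It is illustrative only and does not reproduce the TopoModelX attention modules; all names and weights are hypothetical. Because the GAT score is a monotone function of an \(x\)-term plus a \(y\)-term, every cell \(x\) ranks its neighbors in the same order, whereas the GATv2 score can rank them differently depending on \(x\).

```python
def leaky_relu(v, slope=0.2):
    return v if v > 0 else slope * v


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def matvec(w, v):
    return [dot(row, v) for row in w]


def gat_score(hx, hy, w, a):
    # GAT: LeakyReLU(a^T [W h_x || W h_y])
    return leaky_relu(dot(a, matvec(w, hx) + matvec(w, hy)))


def gatv2_score(hx, hy, w, a):
    # GATv2: a^T LeakyReLU(W [h_x || h_y])
    return dot(a, [leaky_relu(v) for v in matvec(w, hx + hy)])


def best_neighbor(score, hx, neighbors, w, a):
    # Index of the neighbor with the highest attention score for cell x
    return max(range(len(neighbors)), key=lambda j: score(hx, neighbors[j], w, a))


targets = [[1.0], [-1.0]]    # two cells x with scalar features
neighbors = [[1.0], [-1.0]]  # the same two candidate neighbors for each x
w_gat, a_gat = [[1.0]], [1.0, 1.0]
w_v2, a_v2 = [[1.0, 1.0], [-1.0, -1.0]], [1.0, 1.0]

print([best_neighbor(gat_score, hx, neighbors, w_gat, a_gat) for hx in targets])  # [0, 0]
print([best_neighbor(gatv2_score, hx, neighbors, w_v2, a_v2) for hx in targets])  # [0, 1]
```

Both cells prefer the same neighbor under the GAT score, while under the GATv2 score the preferred neighbor changes with the target cell.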

Set-up

[1]:
import numpy as np
import toponetx as tnx
import torch
from sklearn.model_selection import train_test_split
from torch_geometric.datasets import TUDataset
from torch_geometric.utils.convert import to_networkx

from topomodelx.nn.cell.can import CAN
from topomodelx.utils.sparse import from_sparse

torch.manual_seed(0)
[1]:
<torch._C.Generator at 0x11971c5d0>

If GPUs are available, we use them; otherwise, the notebook runs on the CPU.

[2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cpu

Pre-processing

We import MUTAG, a benchmark dataset for graph classification.

We then lift each graph into our topological domain of choice, here: a cell complex.

We also retrieve:

  • input signals x_0 and x_1 on the nodes (0-cells) and edges (1-cells) for each complex: these will be the model’s inputs,

  • a binary classification label y associated to the cell complex.

[3]:
dataset = TUDataset(
    root="/tmp/MUTAG", name="MUTAG", use_edge_attr=True, use_node_attr=True
)
cc_list = []
x_0_list = []
x_1_list = []
y_list = []
for graph in dataset:
    # Lift each graph to a cell complex built from its nodes and edges
    cell_complex = tnx.CellComplex(to_networkx(graph))
    cc_list.append(cell_complex)
    x_0_list.append(graph.x)  # node features
    x_1_list.append(graph.edge_attr)  # edge features
    y_list.append(int(graph.y))  # binary label

# Inspect the last graph of the dataset
print(graph)

i_cc = 0
print(f"Features on nodes for the {i_cc}th cell complex: {x_0_list[i_cc].shape}.")
print(f"Features on edges for the {i_cc}th cell complex: {x_1_list[i_cc].shape}.")
print(f"Label of {i_cc}th cell complex: {y_list[i_cc]}.")
Data(edge_index=[2, 36], x=[16, 7], edge_attr=[36, 4], y=[1])
Features on nodes for the 0th cell complex: torch.Size([17, 7]).
Features on edges for the 0th cell complex: torch.Size([38, 4]).
Label of 0th cell complex: 1.

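Implementing CAN requires performing message passing along the neighborhood structures of the cell complexes.

We therefore retrieve these neighborhood structures (i.e., their representative matrices), which we use to send messages.

We need the matrices \(A_{\downarrow, 1}\) and \(A_{\uparrow, 1}\), which the code below obtains from the down and up Laplacians of rank 1. We also store the node adjacency matrix \(A_{\uparrow, 0}\), which the attentional lift uses to compute edge signals from node signals.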

[4]:
down_laplacian_list = []
up_laplacian_list = []
adjacency_0_list = []

for cell_complex in cc_list:
    # Node adjacency, used by the attentional lift from nodes to edges
    adjacency_0 = cell_complex.adjacency_matrix(rank=0)
    adjacency_0 = torch.from_numpy(adjacency_0.todense()).to_sparse()
    adjacency_0_list.append(adjacency_0)

    # Down Laplacian of the edges: lower neighbors (edges sharing a node)
    down_laplacian_t = cell_complex.down_laplacian_matrix(rank=1)
    down_laplacian_t = from_sparse(down_laplacian_t)
    down_laplacian_list.append(down_laplacian_t)

    # Up Laplacian of the edges: upper neighbors (edges on a common 2-cell).
    # If computing it raises a ValueError (as may happen when the complex has
    # no 2-cells), fall back to an all-zero matrix of the right size.
    try:
        up_laplacian_t = cell_complex.up_laplacian_matrix(rank=1)
        up_laplacian_t = from_sparse(up_laplacian_t)
    except ValueError:
        n_edges = down_laplacian_t.shape[0]
        up_laplacian_t = torch.zeros((n_edges, n_edges)).to_sparse()

    up_laplacian_list.append(up_laplacian_t)
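The matrices above come from toponetx. On a small hand-built complex, the lower and upper neighborhoods of the edges can be read off the off-diagonal nonzeros of \(L_{\downarrow,1} = B_1^\top B_1\) and \(L_{\uparrow,1} = B_2 B_2^\top\), where \(B_1\) and \(B_2\) are the node-to-edge and edge-to-face incidence matrices. This sketch assumes these standard Hodge Laplacian definitions; the toy complex and the helper names are hypothetical. An edge that is not on the boundary of any 2-cell, like e3 below, has no upper neighbors.

```python
def transpose(m):
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]


def neighbors(laplacian):
    # Off-diagonal nonzero entries of row i give the neighbors of cell i
    return [
        [j for j, v in enumerate(row) if j != i and v != 0]
        for i, row in enumerate(laplacian)
    ]


# Toy complex: triangle 0-1-2 filled by one 2-cell, plus a pendant edge 2-3.
# Edges e0=(0,1), e1=(1,2), e2=(0,2), e3=(2,3), each oriented from u to v.
B1 = [  # rows: nodes, columns: edges
    [-1, 0, -1, 0],
    [1, -1, 0, 0],
    [0, 1, 1, -1],
    [0, 0, 0, 1],
]
B2 = [[1], [1], [-1], [0]]  # rows: edges, column: the 2-cell 0 -> 1 -> 2 -> 0

down_laplacian = matmul(transpose(B1), B1)  # edges that share a node
up_laplacian = matmul(B2, transpose(B2))    # edges on a common 2-cell

print(neighbors(down_laplacian))  # [[1, 2], [0, 2, 3], [0, 1, 3], [1, 2]]
print(neighbors(up_laplacian))    # [[1, 2], [0, 2], [0, 1], []]
```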

Create the Neural Network

Using the CAN class, which stacks CANLayer modules, we create a neural network with a small feed-forward readout on top.

[5]:
class Network(torch.nn.Module):
    def __init__(
        self,
        in_channels_0,
        in_channels_1,
        out_channels,
        num_classes,
        dropout=0.5,
        heads=2,
        n_layers=2,
        att_lift=True,
    ):
        super().__init__()
        self.base_model = CAN(
            in_channels_0,
            in_channels_1,
            out_channels,
            dropout=dropout,
            heads=heads,
            n_layers=n_layers,
            att_lift=att_lift,
        )
        self.lin_0 = torch.nn.Linear(out_channels, 128)
        self.lin_1 = torch.nn.Linear(128, num_classes)

    def forward(self, x_0, x_1, adjacency, down_laplacian, up_laplacian):
        x = self.base_model(x_0, x_1, adjacency, down_laplacian, up_laplacian)
        # Max pooling over the edges of each complex
        x = x.max(dim=0)[0]
        # Feed-forward neural network to predict the graph label
        out = self.lin_1(torch.nn.functional.relu(self.lin_0(x)))
        # Return raw logits: CrossEntropyLoss applies log-softmax internally
        return out
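The readout in forward reduces the edge features of each complex to a single vector by taking the maximum over edges, then maps it to two class scores. The plain-Python sketch below shows the max readout and the computation performed by torch.nn.CrossEntropyLoss, which applies a log-softmax to its input internally and therefore expects unnormalized logits. The helper names are hypothetical.

```python
import math


def max_readout(edge_features):
    # Column-wise maximum over all edges: one feature vector per complex
    return [max(column) for column in zip(*edge_features)]


def cross_entropy_from_logits(logits, target):
    # -log(softmax(logits)[target]), using the log-sum-exp trick for stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


edges = [[1.0, -2.0], [0.5, 3.0], [-1.0, 0.0]]  # 3 edges, 2 features each
print(max_readout(edges))  # [1.0, 3.0]
print(cross_entropy_from_logits([0.0, 0.0], target=1))  # log(2), about 0.6931
```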
[6]:
in_channels_0 = x_0_list[0].shape[-1]
in_channels_1 = x_1_list[0].shape[-1]
out_channels = 32
num_classes = 2
heads = 2
n_layers = 2

model = Network(
    in_channels_0,
    in_channels_1,
    out_channels,
    num_classes,
    dropout=0.5,
    heads=heads,
    n_layers=n_layers,
    att_lift=True,
)
model = model.to(device)

Train the Neural Network

We initialize the loss function and the optimizer, then print the model to inspect its layers.

[7]:
crit = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.001)
model
[7]:
Network(
  (base_model): CAN(
    (lift_layer): MultiHeadLiftLayer(
      (lifts): LiftLayer()
    )
    (layers): ModuleList(
      (0): CANLayer(
        (lower_att): MultiHeadCellAttention(
          (att_activation): LeakyReLU(negative_slope=0.2)
          (lin): Linear(in_features=11, out_features=32, bias=False)
        )
        (upper_att): MultiHeadCellAttention(
          (att_activation): LeakyReLU(negative_slope=0.2)
          (lin): Linear(in_features=11, out_features=32, bias=False)
        )
        (lin): Linear(in_features=11, out_features=32, bias=False)
        (aggregation): Aggregation()
      )
      (1): CANLayer(
        (lower_att): MultiHeadCellAttention(
          (att_activation): LeakyReLU(negative_slope=0.2)
          (lin): Linear(in_features=32, out_features=32, bias=False)
        )
        (upper_att): MultiHeadCellAttention(
          (att_activation): LeakyReLU(negative_slope=0.2)
          (lin): Linear(in_features=32, out_features=32, bias=False)
        )
        (lin): Linear(in_features=32, out_features=32, bias=False)
        (aggregation): Aggregation()
      )
    )
  )
  (lin_0): Linear(in_features=32, out_features=128, bias=True)
  (lin_1): Linear(in_features=128, out_features=2, bias=True)
)

We split the dataset into train and test sets.

[8]:
test_size = 0.3
x_1_train, x_1_test = train_test_split(x_1_list, test_size=test_size, shuffle=False)
x_0_train, x_0_test = train_test_split(x_0_list, test_size=test_size, shuffle=False)
down_laplacian_train, down_laplacian_test = train_test_split(
    down_laplacian_list, test_size=test_size, shuffle=False
)
up_laplacian_train, up_laplacian_test = train_test_split(
    up_laplacian_list, test_size=test_size, shuffle=False
)
adjacency_0_train, adjacency_0_test = train_test_split(
    adjacency_0_list, test_size=test_size, shuffle=False
)
y_train, y_test = train_test_split(y_list, test_size=test_size, shuffle=False)
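Because every call above uses shuffle=False and the same test_size, all six lists are cut at the same index, so the i-th element of each training list still refers to the same complex. The sketch below mimics this ordered split in plain Python, assuming scikit-learn's rule of rounding the test-set size up (n_test = ceil(test_size * n)); with the 188 MUTAG complexes this gives 131 training and 57 test complexes. The function name is hypothetical.

```python
import math


def ordered_split(items, test_size=0.3):
    """Plain-Python analogue of train_test_split(items, test_size=..., shuffle=False)."""
    n_test = math.ceil(test_size * len(items))  # test size rounded up
    n_train = len(items) - n_test
    # No shuffling: the first n_train items train, the rest test
    return items[:n_train], items[n_train:]


complex_ids = list(range(188))  # one id per MUTAG complex
train_ids, test_ids = ordered_split(complex_ids, test_size=0.3)
print(len(train_ids), len(test_ids))  # 131 57
```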

Note: The number of epochs below has been kept low to facilitate debugging and testing. Real use cases will likely require more epochs.

[9]:
test_interval = 1
num_epochs = 10
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    num_samples = 0
    correct = 0
    model.train()
    for x_0, x_1, adjacency, down_laplacian, up_laplacian, y in zip(
        x_0_train,
        x_1_train,
        adjacency_0_train,
        down_laplacian_train,
        up_laplacian_train,
        y_train,
        strict=True,
    ):
        x_0 = x_0.float().to(device)
        x_1, y = x_1.float().to(device), torch.tensor(y, dtype=torch.long).to(device)
        adjacency = adjacency.float().to(device)
        down_laplacian, up_laplacian = (
            down_laplacian.float().to(device),
            up_laplacian.float().to(device),
        )
        opt.zero_grad()
        y_hat = model(x_0, x_1, adjacency, down_laplacian, up_laplacian)
        loss = crit(y_hat, y)
        correct += (y_hat.argmax() == y).sum().item()
        num_samples += 1
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    train_acc = correct / num_samples
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {train_acc:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        model.eval()  # disable dropout during evaluation
        with torch.no_grad():
            num_samples = 0
            correct = 0
            for x_0, x_1, adjacency, down_laplacian, up_laplacian, y in zip(
                x_0_test,
                x_1_test,
                adjacency_0_test,
                down_laplacian_test,
                up_laplacian_test,
                y_test,
                strict=True,
            ):
                x_0 = x_0.float().to(device)
                x_1, y = (
                    x_1.float().to(device),
                    torch.tensor(y, dtype=torch.long).to(device),
                )
                adjacency = adjacency.float().to(device)
                down_laplacian, up_laplacian = (
                    down_laplacian.float().to(device),
                    up_laplacian.float().to(device),
                )
                y_hat = model(x_0, x_1, adjacency, down_laplacian, up_laplacian)
                correct += (y_hat.argmax() == y).sum().item()
                num_samples += 1
            test_acc = correct / num_samples
            print(f"Test_acc: {test_acc:.4f}", flush=True)
Epoch: 1 loss: 0.6330 Train_acc: 0.6718
Test_acc: 0.5965
Epoch: 2 loss: 0.6116 Train_acc: 0.6947
Test_acc: 0.5965
Epoch: 3 loss: 0.6071 Train_acc: 0.6947
Test_acc: 0.5965
Epoch: 4 loss: 0.6027 Train_acc: 0.6947
Test_acc: 0.5965
Epoch: 5 loss: 0.5974 Train_acc: 0.7099
Test_acc: 0.6491
Epoch: 6 loss: 0.5911 Train_acc: 0.7252
Test_acc: 0.6491
Epoch: 7 loss: 0.5979 Train_acc: 0.7176
Test_acc: 0.6140
Epoch: 8 loss: 0.5826 Train_acc: 0.7252
Test_acc: 0.6316
Epoch: 9 loss: 0.5908 Train_acc: 0.7252
Test_acc: 0.6316
Epoch: 10 loss: 0.5839 Train_acc: 0.7252
Test_acc: 0.6316
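With 131 training and 57 test complexes (the sizes produced by test_size=0.3 on the 188 MUTAG complexes), each logged accuracy corresponds to a whole number of correct predictions, and one extra correct test prediction moves the test accuracy by \(1/57 \approx 0.0175\). The small hypothetical helper below recovers those counts from the rounded values in the log; for example, the jump from 0.5965 to 0.6491 at epoch 5 amounts to three more correctly classified test complexes.

```python
def correct_count(accuracy, n_samples, decimals=4):
    """Smallest number of correct predictions whose ratio rounds to `accuracy`."""
    return next(
        c for c in range(n_samples + 1)
        if round(c / n_samples, decimals) == accuracy
    )


n_train, n_test = 131, 57
print(correct_count(0.6718, n_train))  # 88 (epoch 1 train accuracy)
print(correct_count(0.5965, n_test))   # 34 (epochs 1-4 test accuracy)
print(correct_count(0.6491, n_test))   # 37 (epochs 5-6 test accuracy)
print(round(1 / n_test, 4))            # 0.0175
```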