Train a Convolutional Cell Complex Network (CCXN)#

We create and train a simplified version of the CCXN originally proposed in Hajij et al.: Cell Complex Neural Networks (2020).

The Neural Network:#

The equations of one layer of this neural network are given by:

  1. A convolution from nodes to nodes using an adjacency message passing scheme (AMPS):

🟥 \(\quad m_{y \rightarrow \{z\} \rightarrow x}^{(0 \rightarrow 1 \rightarrow 0)} = M_{\mathcal{L}_\uparrow}^t(h_x^{t,(0)}, h_y^{t,(0)}, \Theta^{t,(y \rightarrow x)})\)

🟧 \(\quad m_x^{(0 \rightarrow 1 \rightarrow 0)} = AGG_{y \in \mathcal{L}_\uparrow(x)}(m_{y \rightarrow \{z\} \rightarrow x}^{(0 \rightarrow 1 \rightarrow 0)})\)

🟩 \(\quad m_x^{(0)} = m_x^{(0 \rightarrow 1 \rightarrow 0)}\)

🟦 \(\quad h_x^{t+1,(0)} = U^{t}(h_x^{t,(0)}, m_x^{(0)})\)

  2. A convolution from edges to faces using a cohomology message passing scheme:

🟥 \(\quad m_{y \rightarrow x}^{(r' \rightarrow r)} = M^t_{\mathcal{C}}(h_{x}^{t,(r)}, h_y^{t,(r')}, x, y)\)

🟧 \(\quad m_x^{(r' \rightarrow r)} = AGG_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(r' \rightarrow r)}\)

🟩 \(\quad m_x^{(r)} = m_x^{(r' \rightarrow r)}\)

🟦 \(\quad h_{x}^{t+1,(r)} = U^{t,(r)}(h_{x}^{t,(r)}, m_{x}^{(r)})\)

Where the notations are defined in Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023).
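
To make these updates concrete, the sketch below implements both convolutions as sparse matrix products, assuming the simplest instantiation: linear message functions, sum aggregation, and a ReLU update. This is a hypothetical illustration of the equations above, not TopoModelX's actual CCXNLayer.

import torch


def ccxn_layer_sketch(x_0, x_1, adjacency_0, incidence_2_t, theta_0, theta_1):
    """Hypothetical one-layer CCXN forward pass, for illustration only.

    x_0: node features, shape (n_nodes, c_0)
    x_1: edge features, shape (n_edges, c_1)
    adjacency_0: sparse upper adjacency A_{up,0}, shape (n_nodes, n_nodes)
    incidence_2_t: sparse coboundary B_2^T, shape (n_faces, n_edges)
    theta_0, theta_1: learnable weight matrices playing the role of Theta
    """
    # AMPS (red + orange steps): linear messages from upper-adjacent nodes,
    # summed by the sparse product with A_{up,0}.
    m_0 = torch.sparse.mm(adjacency_0, x_0 @ theta_0)
    # Blue step: update node features with the aggregated message.
    x_0_new = torch.relu(m_0)

    # Cohomology scheme (red + orange steps): edge messages routed to the
    # faces they bound, summed by the sparse product with B_2^T.
    m_2 = torch.sparse.mm(incidence_2_t, x_1 @ theta_1)
    # Blue step: update face features.
    x_2_new = torch.relu(m_2)
    return x_0_new, x_2_new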

The Task:#

We train this model to perform classification of entire complexes on a small version of shrec16.

Set-up#

[1]:
import numpy as np
import toponetx as tnx
import torch
from sklearn.model_selection import train_test_split

from topomodelx.nn.cell.ccxn import CCXN
from topomodelx.utils.sparse import from_sparse

torch.manual_seed(0)

%load_ext autoreload
%autoreload 2

If GPUs are available, we will make use of them. Otherwise, this will run on the CPU.

[2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cpu

Pre-processing#

The first step is to import the dataset, shrec16, a benchmark dataset for 3D mesh classification. We then lift each simplicial complex into our domain of choice, a cell complex.

We also retrieve:

- input signals x_0 and x_1 on the nodes (0-cells) and edges (1-cells) for each complex: these will be the model's inputs,
- a scalar classification label y associated to the cell complex.

[3]:
shrec, _ = tnx.datasets.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
simplexes = shrec["complexes"]
Loading shrec 16 small dataset...

done!
[4]:
i_complex = 6
print(
    f"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}."
)
The 6th simplicial complex has 252 nodes with features of dimension 6.
The 6th simplicial complex has 750 edges with features of dimension 10.
The 6th simplicial complex has 500 faces with features of dimension 7.

We lift each simplicial complex into a cell complex.

Then, we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messages on each cell complex. In the case of this architecture, we need the adjacency matrix \(A_{\uparrow, 0}\) and the coboundary matrix \(B_2^T\) of each cell complex.
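
For intuition, here is a hypothetical toy example (assuming toponetx's CellComplex class): a complex with a single triangular 2-cell has a 3 x 3 upper adjacency matrix on its nodes and a 1 x 3 transposed incidence matrix from its three edges to its one face.

# Toy illustration (assumption: tnx.CellComplex is exposed at the package root).
toy = tnx.CellComplex()
toy.add_cell([1, 2, 3], rank=2)  # one triangular face on nodes 1, 2, 3

print(toy.adjacency_matrix(rank=0).todense())    # A_{up,0}: 3 x 3 node adjacency
print(toy.incidence_matrix(rank=2).T.todense())  # B_2^T: 1 face x 3 edges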

[5]:
cc_list = []
incidence_2_t_list = []
adjacency_0_list = []
for simplex in simplexes:
    cell_complex = simplex.to_cell_complex()
    cc_list.append(cell_complex)

    incidence_2_t = cell_complex.incidence_matrix(rank=2).T
    adjacency_0 = cell_complex.adjacency_matrix(rank=0)
    incidence_2_t = from_sparse(incidence_2_t)
    adjacency_0 = from_sparse(adjacency_0)
    incidence_2_t_list.append(incidence_2_t)
    adjacency_0_list.append(adjacency_0)
[6]:
i_complex = 6
print(
    f"The {i_complex}th cell complex has an incidence_2_t matrix of shape {incidence_2_t_list[i_complex].shape}."
)
print(
    f"The {i_complex}th cell complex has an adjacency_0 matrix of shape {adjacency_0_list[i_complex].shape}."
)
The 6th cell complex has an incidence_2_t matrix of shape torch.Size([500, 750]).
The 6th cell complex has an adjacency_0 matrix of shape torch.Size([252, 252]).

Create the Neural Network#

Using the CCXN class, which stacks CCXNLayer modules, we create a neural network with a linear readout for each cell rank.

[7]:
class Network(torch.nn.Module):
    def __init__(
        self,
        in_channels_0,
        in_channels_1,
        in_channels_2,
        num_classes,
        n_layers=2,
        att=False,
    ):
        super().__init__()
        self.base_model = CCXN(
            in_channels_0,
            in_channels_1,
            in_channels_2,
            n_layers=n_layers,
            att=att,
        )
        self.lin_0 = torch.nn.Linear(in_channels_0, num_classes)
        self.lin_1 = torch.nn.Linear(in_channels_1, num_classes)
        self.lin_2 = torch.nn.Linear(in_channels_2, num_classes)

    def forward(self, x_0, x_1, adjacency_0, incidence_2_t):
        x_0, x_1, x_2 = self.base_model(x_0, x_1, adjacency_0, incidence_2_t)
        x_0 = self.lin_0(x_0)
        x_1 = self.lin_1(x_1)
        x_2 = self.lin_2(x_2)
        # Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.
        two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
        two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0
        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0
        zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
        zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0
        # Return the sum of the averages
        return (
            two_dimensional_cells_mean
            + one_dimensional_cells_mean
            + zero_dimensional_cells_mean
        )
[8]:
in_channels_0 = x_0s[0].shape[-1]
in_channels_1 = x_1s[0].shape[-1]
in_channels_2 = 5
num_classes = 2
print(
    f"The dimension of input features on nodes, edges and faces are: {in_channels_0}, {in_channels_1} and {in_channels_2}."
)
model = Network(in_channels_0, in_channels_1, in_channels_2, num_classes, n_layers=2)
model = model.to(device)
The dimension of input features on nodes, edges and faces are: 6, 10 and 5.
[9]:
model
[9]:
Network(
  (base_model): CCXN(
    (layers): ModuleList(
      (0-1): 2 x CCXNLayer(
        (conv_0_to_0): Conv()
        (conv_1_to_2): Conv()
      )
    )
  )
  (lin_0): Linear(in_features=6, out_features=2, bias=True)
  (lin_1): Linear(in_features=10, out_features=2, bias=True)
  (lin_2): Linear(in_features=5, out_features=2, bias=True)
)

Train the Neural Network#

We initialize the loss and specify an optimizer. We first try the model without any attention mechanism.

[10]:
opt = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()

We split the dataset into train and test sets.

[11]:
test_size = 0.2
x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)
x_1_train, x_1_test = train_test_split(x_1s, test_size=test_size, shuffle=False)
incidence_2_t_train, incidence_2_t_test = train_test_split(
    incidence_2_t_list, test_size=test_size, shuffle=False
)
adjacency_0_train, adjacency_0_test = train_test_split(
    adjacency_0_list, test_size=test_size, shuffle=False
)
y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)

We train the CCXN with a low number of epochs: we keep training minimal for the purpose of rapid testing.

[12]:
test_interval = 2
num_epochs = 10
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
        x_0_train,
        x_1_train,
        incidence_2_t_train,
        adjacency_0_train,
        y_train,
        strict=True,
    ):
        x_0, x_1, y = (
            torch.tensor(x_0).float().to(device),
            torch.tensor(x_1).float().to(device),
            torch.tensor(y).float().to(device),
        )
        incidence_2_t, adjacency_0 = (
            incidence_2_t.float().to(device),
            adjacency_0.float().to(device),
        )
        opt.zero_grad()
        y_hat = model(x_0, x_1, adjacency_0, incidence_2_t)
        loss = loss_fn(y_hat, y)
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())

    if epoch_i % test_interval == 0:
        with torch.no_grad():
            train_mean_loss = np.mean(epoch_loss)
            for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
                x_0_test,
                x_1_test,
                incidence_2_t_test,
                adjacency_0_test,
                y_test,
                strict=True,
            ):
                x_0, x_1, y = (
                    torch.tensor(x_0).float().to(device),
                    torch.tensor(x_1).float().to(device),
                    torch.tensor(y).float().to(device),
                )
                incidence_2_t, adjacency_0 = (
                    incidence_2_t.float().to(device),
                    adjacency_0.float().to(device),
                )
                y_hat = model(x_0, x_1, adjacency_0, incidence_2_t)
                test_loss = loss_fn(y_hat, y)
            print(
                f"Epoch:{epoch_i}, Train Loss: {train_mean_loss:.4f} Test Loss: {test_loss:.4f}",
                flush=True,
            )
Epoch:2, Train Loss: 83.8803 Test Loss: 72.8717
/Users/gbg141/Documents/Projects/TopoModelX/venv_tmx/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
Epoch:4, Train Loss: 80.8463 Test Loss: 74.7231
Epoch:6, Train Loss: 77.9684 Test Loss: 75.5384
Epoch:8, Train Loss: 75.5704 Test Loss: 76.0005
Epoch:10, Train Loss: 73.3453 Test Loss: 78.1194
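
The warning above comes from the readout: the model returns a vector with num_classes entries, while each label y is a scalar, so the MSE silently broadcasts. A more conventional evaluation for this classification task would compare the argmax of the output to the integer label; here is a minimal sketch, assuming inputs have already been converted to tensors on the right device, as in the loop above.

def accuracy(model, x_0s, x_1s, adjacencies, incidences, labels):
    """Hypothetical helper: fraction of complexes whose argmax matches the label."""
    correct = 0
    model.eval()
    with torch.no_grad():
        for x_0, x_1, adjacency_0, incidence_2_t, y in zip(
            x_0s, x_1s, adjacencies, incidences, labels, strict=True
        ):
            y_hat = model(x_0, x_1, adjacency_0, incidence_2_t)
            correct += int(y_hat.argmax().item() == int(y))
    return correct / len(labels)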

Train the Neural Network with Attention#

Now we create a new neural network that uses the attention mechanism.

[13]:
model = Network(
    in_channels_0, in_channels_1, in_channels_2, num_classes, n_layers=2, att=True
)
model = model.to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()

We run the training for this neural network:

Note: The number of epochs below has been kept low to facilitate debugging and testing. Real use cases will likely require more epochs.

[14]:
test_interval = 2
num_epochs = 10
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
        x_0_train,
        x_1_train,
        incidence_2_t_train,
        adjacency_0_train,
        y_train,
        strict=True,
    ):
        x_0, x_1, y = (
            torch.tensor(x_0).float().to(device),
            torch.tensor(x_1).float().to(device),
            torch.tensor(y).float().to(device),
        )
        incidence_2_t, adjacency_0 = (
            incidence_2_t.float().to(device),
            adjacency_0.float().to(device),
        )
        opt.zero_grad()
        y_hat = model(x_0, x_1, adjacency_0, incidence_2_t)
        loss = loss_fn(y_hat, y)
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())

    if epoch_i % test_interval == 0:
        with torch.no_grad():
            train_mean_loss = np.mean(epoch_loss)
            for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
                x_0_test,
                x_1_test,
                incidence_2_t_test,
                adjacency_0_test,
                y_test,
                strict=True,
            ):
                x_0, x_1, y = (
                    torch.tensor(x_0).float().to(device),
                    torch.tensor(x_1).float().to(device),
                    torch.tensor(y).float().to(device),
                )
                incidence_2_t, adjacency_0 = (
                    incidence_2_t.float().to(device),
                    adjacency_0.float().to(device),
                )
                y_hat = model(x_0, x_1, adjacency_0, incidence_2_t)
                test_loss = loss_fn(y_hat, y)
            print(
                f"Epoch:{epoch_i}, Train Loss: {train_mean_loss:.4f} Test Loss: {test_loss:.4f}",
                flush=True,
            )
Epoch:2, Train Loss: 86.7494 Test Loss: 81.5152
Epoch:4, Train Loss: 79.9747 Test Loss: 85.7426
Epoch:6, Train Loss: 76.6452 Test Loss: 89.2596
Epoch:8, Train Loss: 74.2124 Test Loss: 88.7206
Epoch:10, Train Loss: 72.5803 Test Loss: 87.3506