Train a CW Network (CWN)

We create and train a specific version of the CWN originally proposed in Bodnar et al., Weisfeiler and Lehman Go Cellular: CW Networks (2021).

The Neural Network:

The equations for a single layer of this neural network are given by:

🟥 \(\quad m_{y \rightarrow \{z\} \rightarrow x}^{(r \rightarrow r' \rightarrow r)} = M_{\mathcal{L}\uparrow}(h_x^{t,(r)}, h_y^{t,(r)}, h_z^{t,(r')})\)

🟥 \(\quad m_{y \rightarrow x}^{(r'' \rightarrow r)} = M_{\mathcal{B}}(h_x^{t,(r)}, h_y^{t,(r'')})\)

🟧 \(\quad m_x^{(r'' \rightarrow r)} = AGG_{y \in \mathcal{B}(x)} m_{y \rightarrow x}^{(r'' \rightarrow r)}\)

🟧 \(\quad m_x^{(r \rightarrow r' \rightarrow r)} = AGG_{y \in \mathcal{L}(x)} m_{y \rightarrow \{z\} \rightarrow x}^{(r \rightarrow r' \rightarrow r)}\)

🟩 \(\quad m_x^{(r)} = AGG_{\mathcal{N}_k \in \mathcal{N}} (m_x^k)\)

🟦 \(\quad h_x^{t+1,(r)} = U\left(h_x^{t,(r)}, m_x^{(r)}\right)\)
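As a rough illustration of these equations for edges (\(r = 1\)), the NumPy sketch below performs one update on a tiny complex. It relies on simplifying assumptions that are not taken from TopoModelX: both message functions are linear in the sender features only, every aggregation (🟧 and 🟩) is a sum, and the update \(U\) adds a linear map of \(h_x\) to the aggregated messages before a ReLU. The function `cwn_edge_update` and all weight matrices are hypothetical.

```python
import numpy as np


def cwn_edge_update(h0, h1, h2, b1, b2, w_up, w_face, w_bd, w_self):
    """One hypothetical CWN-style update of edge features (rank r = 1).

    h0, h1, h2: node, edge and face features (one row per cell).
    b1: unsigned node-edge incidence matrix, shape (n_nodes, n_edges).
    b2: unsigned edge-face incidence matrix, shape (n_edges, n_faces).
    """
    m_up = np.zeros((h1.shape[0], w_up.shape[1]))
    # Upper adjacency: edges x and y exchange a message through each face z they share.
    for z in range(b2.shape[1]):
        edges_on_z = np.flatnonzero(b2[:, z])
        for x in edges_on_z:
            for y in edges_on_z:
                if x != y:
                    # Assumed M_{L up}: linear in h_y and h_z (h_x is ignored), summed over y.
                    m_up[x] += h1[y] @ w_up + h2[z] @ w_face
    # Boundary messages from the two nodes of each edge, assumed linear and summed.
    m_bd = b1.T @ h0 @ w_bd
    # Sum over the two neighborhoods, then apply the assumed update U (linear + ReLU).
    return np.maximum(0.0, h1 @ w_self + m_up + m_bd)


# A single triangle: nodes 0, 1, 2; edges (0, 1), (1, 2), (0, 2); one face.
b1 = np.array([[1, 0, 1], [1, 1, 0], [0, 1, 1]], dtype=float)
b2 = np.ones((3, 1))
one = np.ones((1, 1))
out = cwn_edge_update(
    np.ones((3, 1)), np.ones((3, 1)), np.ones((1, 1)), b1, b2, one, one, one, one
)
print(out.ravel())  # [7. 7. 7.]: 1 (self) + 2 * (1 + 1) (upper) + 2 (boundary)
```

Summing keeps the example easy to check by hand; the actual CWN layers use learned message and update functions instead.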

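The notation follows Papillon et al., Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023). In particular, \(r' = r + 1\) is the rank of the cells \(z\) shared by upper-adjacent cells \(x\) and \(y\), and \(r'' = r - 1\) is the rank of the boundary cells \(y \in \mathcal{B}(x)\).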

The Task:

We train this model to classify entire complexes (one label per complex) on a small version of shrec16.

Set-up

 In [1]:
import numpy as np
import toponetx.datasets as datasets
import torch
from sklearn.model_selection import train_test_split

from topomodelx.nn.cell.cwn import CWN
from topomodelx.utils.sparse import from_sparse

torch.manual_seed(0)
 Out [1]:
<torch._C.Generator at 0x16d07b750>

If GPUs are available, we will use them. Otherwise, the notebook runs on the CPU.

 In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cpu

Pre-processing

The first step is to import the dataset, shrec16, a benchmark dataset for 3D mesh classification. Each mesh is loaded as a simplicial complex, which we then lift into our domain of choice, a cell complex.

We also retrieve:
- input signals x_0, x_1, x_2 on the nodes (0-cells), edges (1-cells), and faces (2-cells) of each complex: these will be the model’s inputs;
- a scalar classification label y associated with each complex.

 In [3]:
shrec, _ = datasets.mesh.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
simplexes = shrec["complexes"]
Loading shrec 16 small dataset...

done!
 In [4]:
i_complex = 6
print(
    f"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes "
    f"with features of dimension {x_0s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges "
    f"with features of dimension {x_1s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces "
    f"with features of dimension {x_2s[i_complex].shape[1]}."
)
The 6th simplicial complex has 252 nodes with features of dimension 6.
The 6th simplicial complex has 750 edges with features of dimension 10.
The 6th simplicial complex has 500 faces with features of dimension 7.

We lift each simplicial complex into a cell complex.

Then, we retrieve the neighborhood structures (i.e., their representative matrices) that we use to send messages on each cell complex. This architecture needs the upper adjacency matrix \(A_{\uparrow, r}\), the coboundary matrix \(B_r^{\intercal}\), and the boundary matrix \(B_{r+1}\), here with \(r = 1\): every matrix has one row per edge.
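To make these matrices concrete, here is a small NumPy sketch, independent of TopoNetX (whose conventions, e.g. signs or weights, may differ), that builds unsigned versions of \(B_1\), \(B_2\) and \(A_{\uparrow,1}\) for two triangles glued along an edge. The helper names are hypothetical. Two edges are treated as upper adjacent when they lie on a common face, i.e. when the corresponding off-diagonal entry of \(B_2 B_2^{\intercal}\) is nonzero.

```python
import numpy as np


def incidence_matrices(n_nodes, edges, faces):
    """Unsigned B_1 (nodes x edges) and B_2 (edges x faces) of a small 2D cell complex.

    edges: list of node pairs; faces: list of node cycles, e.g. [0, 1, 2].
    """
    edge_index = {frozenset(e): j for j, e in enumerate(edges)}
    b1 = np.zeros((n_nodes, len(edges)))
    for j, (u, v) in enumerate(edges):
        b1[u, j] = b1[v, j] = 1.0
    b2 = np.zeros((len(edges), len(faces)))
    for k, cycle in enumerate(faces):
        # The boundary of a face is the set of consecutive node pairs along its cycle.
        for a, b in zip(cycle, cycle[1:] + cycle[:1]):
            b2[edge_index[frozenset((a, b))], k] = 1.0
    return b1, b2


def upper_adjacency(b2):
    """Edges are upper adjacent if they share at least one face; no self-loops."""
    a_up = (b2 @ b2.T > 0).astype(float)
    np.fill_diagonal(a_up, 0.0)
    return a_up


edges = [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)]
b1, b2 = incidence_matrices(4, edges, [[0, 1, 2], [1, 2, 3]])
a_up = upper_adjacency(b2)
b1_t = b1.T
print(a_up.shape, b2.shape, b1_t.shape)  # (5, 5) (5, 2) (5, 4)
print(a_up.sum(axis=1))  # [2. 2. 4. 2. 2.]: edge (1, 2) borders both faces
```

The three shapes follow the same pattern as the notebook's adjacency_1, incidence_2 and incidence_1_t: edges × edges, edges × faces, and edges × nodes.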

 In [5]:
cc_list = []
incidence_2_list = []
adjacency_1_list = []
incidence_1_t_list = []

for simplex in simplexes:
    cell_complex = simplex.to_cell_complex()
    cc_list.append(cell_complex)

    incidence_2 = cell_complex.incidence_matrix(rank=2)
    adjacency_1 = cell_complex.adjacency_matrix(rank=1)
    incidence_1_t = cell_complex.incidence_matrix(rank=1).T

    incidence_2 = from_sparse(incidence_2)
    adjacency_1 = from_sparse(adjacency_1)
    incidence_1_t = from_sparse(incidence_1_t)

    incidence_2_list.append(incidence_2)
    adjacency_1_list.append(adjacency_1)
    incidence_1_t_list.append(incidence_1_t)
 In [6]:
i_complex = 6

print(
    f"The {i_complex}th cell complex has an adjacency_1 matrix "
    f"of shape {adjacency_1_list[i_complex].shape}."
)
print(
    f"The {i_complex}th cell complex has an incidence_2 matrix "
    f"of shape {incidence_2_list[i_complex].shape}."
)
print(
    f"The {i_complex}th cell complex has an incidence_1_t matrix "
    f"of shape {incidence_1_t_list[i_complex].shape}."
)
The 6th cell complex has an adjacency_1 matrix of shape torch.Size([750, 750]).
The 6th cell complex has an incidence_2 matrix of shape torch.Size([750, 500]).
The 6th cell complex has an incidence_1_t matrix of shape torch.Size([750, 252]).

Create the Neural Network

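Using the CWN class, which stacks CWN layers, we create a neural network with a linear readout for each rank. The network averages each rank's outputs over its cells and sums the three averages into a single scalar prediction per complex.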

 In [7]:
class Network(torch.nn.Module):
    def __init__(
        self,
        in_channels_0,
        in_channels_1,
        in_channels_2,
        hid_channels=16,
        num_classes=1,
        n_layers=2,
    ):
        super().__init__()
        self.base_model = CWN(
            in_channels_0,
            in_channels_1,
            in_channels_2,
            hid_channels=hid_channels,
            n_layers=n_layers,
        )
        self.lin_0 = torch.nn.Linear(hid_channels, num_classes)
        self.lin_1 = torch.nn.Linear(hid_channels, num_classes)
        self.lin_2 = torch.nn.Linear(hid_channels, num_classes)

    def forward(
        self,
        x_0,
        x_1,
        x_2,
        adjacency_1,
        incidence_2,
        incidence_1_t,
    ):
        x_0, x_1, x_2 = self.base_model(
            x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t
        )
        x_0 = self.lin_0(x_0)
        x_1 = self.lin_1(x_1)
        x_2 = self.lin_2(x_2)

        # Average the 2D, 1D, and 0D cell features. torch.nanmean returns NaN when a
        # rank has no cells (or only NaN values), so NaN results are replaced with 0.
        two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
        two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0

        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0

        zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
        zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0

        # Return the sum of the averages
        return (
            two_dimensional_cells_mean
            + one_dimensional_cells_mean
            + zero_dimensional_cells_mean
        )
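The readout at the end of forward can be mirrored in a few lines of NumPy, which makes the handling of an empty rank explicit. This is a hypothetical helper for illustration, not part of TopoModelX, and it assumes the features contain no NaN values: each rank's per-cell predictions are averaged, a rank without cells contributes zero (the role of the NaN-to-0 replacement above), and the three averages are summed.

```python
import numpy as np


def readout(x_0, x_1, x_2):
    """Sum of per-rank means of cell-level predictions (hypothetical helper)."""
    total = np.zeros(x_0.shape[1])
    for x in (x_0, x_1, x_2):
        # An empty rank contributes 0, like the NaN -> 0 replacement in Network.forward.
        if x.shape[0] > 0:
            total += x.mean(axis=0)
    return total


# Two nodes, one edge, and no faces, each with a single output channel.
result = readout(np.array([[1.0], [3.0]]), np.array([[4.0]]), np.zeros((0, 1)))
print(result)  # [6.]: mean of nodes (2) + mean of edges (4) + empty faces (0)
```

Averaging within each rank, rather than summing, keeps the prediction on a comparable scale for meshes with very different numbers of cells.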
 In [8]:
in_channels_0 = x_0s[0].shape[-1]
in_channels_1 = x_1s[0].shape[-1]
in_channels_2 = x_2s[0].shape[-1]

print(
    f"The dimensions of input features on nodes, edges and faces are "
    f"{in_channels_0}, {in_channels_1} and {in_channels_2}, respectively."
)
model = Network(
    in_channels_0,
    in_channels_1,
    in_channels_2,
    hid_channels=16,
    num_classes=1,
    n_layers=2,
)
model = model.to(device)
The dimensions of input features on nodes, edges and faces are 6, 10 and 7, respectively.

Train the Neural Network

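We specify an optimizer and define a loss function. Because the network outputs a single scalar per complex, the loss is the mean squared error between this scalar and the label y.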

 In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

We split the dataset into train and test sets.

 In [10]:
test_size = 0.2

x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)
x_1_train, x_1_test = train_test_split(x_1s, test_size=test_size, shuffle=False)
x_2_train, x_2_test = train_test_split(x_2s, test_size=test_size, shuffle=False)

adjacency_1_train, adjacency_1_test = train_test_split(
    adjacency_1_list, test_size=test_size, shuffle=False
)
incidence_2_train, incidence_2_test = train_test_split(
    incidence_2_list, test_size=test_size, shuffle=False
)
incidence_1_t_train, incidence_1_t_test = train_test_split(
    incidence_1_t_list, test_size=test_size, shuffle=False
)

y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)
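Because shuffle=False, each train_test_split call cuts its list at the same position, so the i-th features, neighborhood matrices and label still describe the same complex after splitting. The plain-Python sketch below reproduces that behavior, assuming scikit-learn rounds the test size up to ceil(test_size × n) samples; sequential_split is a hypothetical name.

```python
import math


def sequential_split(items, test_size=0.2):
    """Split without shuffling: the last ceil(test_size * n) items form the test set."""
    n_test = math.ceil(test_size * len(items))
    n_train = len(items) - n_test
    return items[:n_train], items[n_train:]


labels = list(range(10))
features = [f"complex_{i}" for i in labels]
labels_train, labels_test = sequential_split(labels)
features_train, features_test = sequential_split(features)
# The two splits stay aligned: position i refers to the same complex in both.
print(labels_test, features_test)  # [8, 9] ['complex_8', 'complex_9']
```

Shuffling each list independently would break this alignment unless every call used the same random_state.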

We train the CWN for only a few epochs so that this notebook runs quickly for testing and debugging. Real use cases will likely require more epochs.

 In [11]:
test_interval = 2
num_epochs = 10

for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()

    for x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t, y in zip(
        x_0_train,
        x_1_train,
        x_2_train,
        adjacency_1_train,
        incidence_2_train,
        incidence_1_t_train,
        y_train,
        strict=True,
    ):
        x_0, x_1, x_2, y = (
            torch.tensor(x_0).float().to(device),
            torch.tensor(x_1).float().to(device),
            torch.tensor(x_2).float().to(device),
            torch.tensor([y]).float().to(device),
        )

        adjacency_1 = adjacency_1.float().to(device)
        incidence_2 = incidence_2.float().to(device)
        incidence_1_t = incidence_1_t.float().to(device)

        optimizer.zero_grad()
        y_hat = model(x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())

    if epoch_i % test_interval == 0:
        with torch.no_grad():
            train_mean_loss = np.mean(epoch_loss)
            for x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t, y in zip(
                x_0_test,
                x_1_test,
                x_2_test,
                adjacency_1_test,
                incidence_2_test,
                incidence_1_t_test,
                y_test,
                strict=True,
            ):
                x_0, x_1, x_2, y = (
                    torch.tensor(x_0).float().to(device),
                    torch.tensor(x_1).float().to(device),
                    torch.tensor(x_2).float().to(device),
                    torch.tensor([y]).float().to(device),
                )

                adjacency_1 = adjacency_1.float().to(device)
                incidence_2 = incidence_2.float().to(device)
                incidence_1_t = incidence_1_t.float().to(device)

                y_hat = model(x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t)
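                # Overwritten on each iteration: the printed value is the loss
                # on the last test complex, not an average over the test set.
                test_loss = criterion(y_hat, y)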
            print(
                f"Epoch:{epoch_i}, Train Loss: {train_mean_loss:.4f} Test Loss: {test_loss:.4f}",
                flush=True,
            )
Epoch:2, Train Loss: 83.8053 Test Loss: 73.7517
Epoch:4, Train Loss: 81.9551 Test Loss: 50.2781
Epoch:6, Train Loss: 78.3991 Test Loss: 49.9035
Epoch:8, Train Loss: 75.8110 Test Loss: 45.7197
Epoch:10, Train Loss: 74.3838 Test Loss: 40.5566
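To summarize the whole test split rather than a single complex, the per-complex losses can be collected and averaged. The helper below is a hypothetical, framework-free sketch: predict stands in for a call to the model and loss_fn for the criterion; with the real model, one would append test_loss.item() for each test complex and average the list. It assumes the split contains at least one sample.

```python
def mean_loss(predict, samples, loss_fn):
    """Average a per-sample loss over a whole split (hypothetical helper).

    predict: maps a sample's inputs to a prediction.
    samples: iterable of (inputs, target) pairs, assumed non-empty.
    loss_fn: maps (prediction, target) to a float.
    """
    losses = [loss_fn(predict(inputs), target) for inputs, target in samples]
    return sum(losses) / len(losses)


def squared_error(prediction, target):
    return (prediction - target) ** 2


def identity(inputs):
    return inputs


samples = [(1.0, 2.0), (3.0, 3.0)]
print(mean_loss(identity, samples, squared_error))  # 0.5, while the last sample alone gives 0.0
```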