Train a DHGCN TNN#
In this notebook, we create and train a DHGCN, a two-step message passing network in the hypergraph domain. We use shrec 16, a benchmark collection of 3D meshes, to train the model to perform classification at the level of the hypergraph.
[1]:
import numpy as np
import toponetx as tnx
import torch
from sklearn.model_selection import train_test_split
from topomodelx.nn.hypergraph.dhgcn import DHGCN
from topomodelx.utils.sparse import from_sparse
If GPUs are available, we will make use of them. Otherwise, the notebook runs on the CPU.
[2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda
Pre-processing#
Import data#
The first step is to import the dataset, shrec 16, a benchmark dataset for 3D mesh classification. We then lift each mesh into our domain of choice, a hypergraph.
We will also retrieve:
the input signals on the nodes, edges, and faces of each mesh (only the node features are fed to the model in this notebook)
the label associated with each hypergraph
[3]:
shrec, _ = tnx.datasets.mesh.shrec_16(size="small")
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]
ys = shrec["label"]
simplexes = shrec["complexes"]
Loading shrec 16 small dataset...
done!
[5]:
x_0s[4].shape, x_1s[0].shape, x_2s[0].shape
# list(dir(simplexes[0]))[40:]
[5]:
((252, 6), (750, 10), (500, 7))
[6]:
i_complex = 6
print(
    f"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}."
)
The 6th simplicial complex has 252 nodes with features of dimension 6.
The 6th simplicial complex has 750 edges with features of dimension 10.
The 6th simplicial complex has 500 faces with features of dimension 7.
Define neighborhood structures and lift into hypergraph domain.#
Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messages on each simplicial complex. In the case of this architecture, we need the boundary matrix (or incidence matrix) \(B_1\) with shape \(n_\text{nodes} \times n_\text{edges}\).
Once we have recorded the incidence matrix (note that all incidence matrices in the hypergraph domain must be unsigned), we lift each simplicial complex into a hypergraph. The pairwise edges become pairwise hyperedges, and the faces of the simplicial complex become 3-wise hyperedges.
[7]:
hg_list = []
incidence_1_list = []
for simplex in simplexes:
    incidence_1 = simplex.incidence_matrix(rank=1, signed=False)
    incidence_1 = from_sparse(incidence_1)
    incidence_1_list.append(incidence_1)
    hg = simplex.to_hypergraph()
    hg_list.append(hg)
[8]:
i_complex = 6
print(
    f"The {i_complex}th hypergraph has an incidence matrix of shape {incidence_1_list[i_complex].shape}."
)
The 6th hypergraph has an incidence matrix of shape torch.Size([252, 750]).
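To make the shape \(n_\text{nodes} \times n_\text{edges}\) and the lifting rule concrete, here is a minimal sketch on a toy complex made of two triangles that share an edge. It is written in plain Python and only mirrors the description above; it is not how toponetx builds these objects, and the variable names `faces`, `edges`, and `hyperedges` are illustrative assumptions.

```python
from itertools import combinations

# A toy simplicial complex: two triangles sharing the edge (1, 2).
faces = [(0, 1, 2), (1, 2, 3)]
nodes = sorted({v for face in faces for v in face})
edges = sorted({pair for face in faces for pair in combinations(face, 2)})

# Unsigned incidence matrix B_1 of shape (n_nodes, n_edges):
# entry [v][e] is 1 when node v is an endpoint of edge e, else 0.
incidence_1 = [[1 if v in edge else 0 for edge in edges] for v in nodes]

# Lifting as described in the text: every edge becomes a 2-node
# hyperedge and every face becomes a 3-node hyperedge.
hyperedges = [frozenset(e) for e in edges] + [frozenset(f) for f in faces]

print(f"B_1 shape: ({len(incidence_1)}, {len(incidence_1[0])})")
for row in incidence_1:
    print(row)
print(f"{len(hyperedges)} hyperedges after lifting")
```

Each column of the unsigned matrix contains exactly two ones, one per endpoint of the corresponding edge. A signed boundary matrix would instead carry one +1 and one -1 per column.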
Create the Neural Network#
Define the network that initializes the base model and sets up the readout operation. Different downstream tasks might require different pooling procedures.
[9]:
class Network(torch.nn.Module):
    """Network class that initializes the base model and readout layer.
    Base model parameters:
    ----------
    Required:
    in_channels : int
        Dimension of the input features.
    hidden_channels : int
        Dimension of the hidden features.
    Optional:
    **kwargs : dict
        Additional arguments for the base model.
    Readout layer parameters:
    ----------
    out_channels : int
        Dimension of the output features.
    task_level : str
        Level of the task. Either "graph" or "node".
    """
    def __init__(
        self, in_channels, hidden_channels, out_channels, task_level="graph", **kwargs
    ):
        super().__init__()
        # Define the model
        self.base_model = DHGCN(
            in_channels=in_channels, hidden_channels=hidden_channels, **kwargs
        )
        # Readout
        self.linear = torch.nn.Linear(hidden_channels, out_channels)
        self.out_pool = task_level == "graph"
    def forward(self, x_0):
        # Base model
        x_0, x_1 = self.base_model(x_0)
        # Pool over all nodes in the hypergraph
        x = torch.max(x_0, dim=0)[0] if self.out_pool is True else x_0
        return self.linear(x)
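The two readout modes can be illustrated with a small sketch. It uses plain Python lists in place of tensors so the arithmetic is visible; the function `readout` and the values of `weights` and `bias` are hypothetical stand-ins for the `torch.nn.Linear` layer, not part of the notebook's model.

```python
def readout(node_features, weights, bias, task_level="graph"):
    """Apply a linear readout to pooled (graph) or per-node (node) features."""

    def linear(row):
        return sum(w * v for w, v in zip(weights, row)) + bias

    if task_level == "graph":
        # Max over nodes, taken independently for each feature column,
        # which is what torch.max(x_0, dim=0)[0] computes.
        pooled = [max(column) for column in zip(*node_features)]
        return [linear(pooled)]
    # Node-level task: one prediction per node, no pooling.
    return [linear(row) for row in node_features]


x_0 = [[1.0, -2.0], [0.5, 3.0], [-1.0, 0.0]]  # 3 nodes, 2 hidden channels
weights, bias = [2.0, 1.0], 0.5  # hypothetical readout parameters

graph_out = readout(x_0, weights, bias, task_level="graph")
node_out = readout(x_0, weights, bias, task_level="node")
print(graph_out)  # a single value for the whole hypergraph
print(node_out)   # one value per node
```

With graph-level pooling the output size does not depend on the number of nodes, which is what lets meshes of different sizes share one classifier head.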
[10]:
# Base model hyperparameters
in_channels = x_0s[0].shape[1]
hidden_channels = 6
n_layers = 2
# Readout hyperparameters
out_channels = 1
task_level = "graph" if out_channels == 1 else "node"
model = Network(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    n_layers=n_layers,
    task_level=task_level,
).to(device)
Train the Neural Network#
We specify the loss function and an optimizer for the model defined above.
[11]:
loss_fn = torch.nn.MSELoss()
opt = torch.optim.Adam(model.parameters(), lr=0.1)
Split the dataset into train and test sets.
[12]:
test_size = 0.2
x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)
y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)
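Features and labels are split in two separate calls, which only stays consistent because `shuffle=False` and the same `test_size` are used in both. The sketch below mimics an ordered split in plain Python to show why the pairs remain aligned; `ordered_split` is a hypothetical helper written for illustration, not scikit-learn's implementation.

```python
import math


def ordered_split(items, test_size):
    """Split a sequence into (train, test) parts without shuffling."""
    n_test = math.ceil(test_size * len(items))
    n_train = len(items) - n_test
    return items[:n_train], items[n_train:]


features = [f"mesh_{i}" for i in range(10)]
labels = list(range(10))

x_train, x_test = ordered_split(features, test_size=0.2)
y_train, y_test = ordered_split(labels, test_size=0.2)

# Because both lists are cut at the same index, sample i keeps its label.
for feature, label in zip(x_test, y_test):
    print(feature, label)
```

If shuffling were enabled, the two calls would need a shared random seed, or features and labels should be passed to a single call, so that both are permuted identically.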
The following cell performs the training, looping over the training set for a small number of epochs.
Note: The number of epochs below has been kept low to facilitate debugging and testing. Real use cases will likely require more epochs.
[13]:
test_interval = 1
num_epochs = 1
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_0, y in zip(x_0_train, y_train, strict=True):
        x_0 = torch.tensor(x_0)
        x_0, y = (
            x_0.float().to(device),
            torch.tensor(y, dtype=torch.float).to(device),
        )
        opt.zero_grad()
        y_hat = model(x_0)
        loss = loss_fn(y_hat, y)
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            for x_0, y in zip(x_0_test, y_test, strict=True):
                x_0 = torch.tensor(x_0)
                x_0, y = (
                    x_0.float().to(device),
                    torch.tensor(y, dtype=torch.float).to(device),
                )
                y_hat = model(x_0)
                loss = loss_fn(y_hat, y)
            print(f"Test_loss: {loss:.4f}", flush=True)
/usr/local/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
return F.mse_loss(input, target, reduction=self.reduction)
Epoch: 1 loss: 11062303.5101
Test_loss: 135039.8750
Epoch: 2 loss: 12746.0434
Test_loss: 7911.8130
Epoch: 3 loss: 1477.9950
Test_loss: 228.0758
Epoch: 4 loss: 702.3154
Test_loss: 10.2958
Epoch: 5 loss: 872.6976
Test_loss: 1025.2914
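The UserWarning in the output above appears because the model returns a tensor of shape `[1]` while the target built with `torch.tensor(y, dtype=torch.float)` is a scalar of shape `[]`. The sketch below reproduces the mismatch on made-up numbers and shows one possible way to silence it, reshaping the target to the prediction's shape. The values `0.5` and `2.0` are illustrative assumptions, not taken from the dataset.

```python
import warnings

import torch

loss_fn = torch.nn.MSELoss()
y_hat = torch.tensor([0.5])  # shape [1], like the readout output for one hypergraph
y = torch.tensor(2.0)        # shape [], like torch.tensor(label, dtype=torch.float)

# Mismatched shapes: PyTorch broadcasts the target and emits a UserWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    loss_mismatched = loss_fn(y_hat, y)

# Matching shapes: reshape the target to the shape of the prediction.
with warnings.catch_warnings(record=True) as caught_fixed:
    warnings.simplefilter("always")
    loss_matched = loss_fn(y_hat, y.reshape(y_hat.shape))

size_warnings = [w for w in caught if "target size" in str(w.message)]
size_warnings_fixed = [w for w in caught_fixed if "target size" in str(w.message)]
print(len(size_warnings), len(size_warnings_fixed))
print(loss_mismatched.item(), loss_matched.item())
```

For a single-element prediction the broadcast result equals the reshaped one, so the loss value itself is unchanged here. The warning matters more when batched predictions and targets have shapes that broadcast into an unintended pairwise comparison.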