Train a Hypergraph Neural Network#

In this notebook, we will create and train a two-step message passing network HyperGAT (Ding et al., 2020) in the hypergraph domain. We will use a benchmark dataset, shrec16, a collection of 3D meshes, to train the model to perform classification at the level of the hypergraph.

Given a hypergraph \(G=(\mathcal{V}, \mathcal{E})\), where \(|\mathcal{V}|=n, |\mathcal{E}|=m\), let \(X \in \mathbb{R}^{n \times d}\) and \(Z \in \mathbb{R}^{m \times d'}\) denote the hidden node and hyperedge representations, respectively.

🟥 \(\quad m_{y \rightarrow z}^{(0 \rightarrow 1) } = (B^T_1\odot att(h_{y \in \mathcal{B}(z)}^{t,(0)}))_{zy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(0)}\)

🟧 \(\quad m_z^{(1)} = \sigma(\sum_{y \in \mathcal{B}(z)} m_{y \rightarrow z}^{(0 \rightarrow 1)})\)

🟥 \(\quad m_{z \rightarrow x}^{(1 \rightarrow 0)} = (B_1 \odot att(h_{z \in \mathcal{C}(x)}^{t,(1)}))_{xz} \cdot m_{z}^{(1)} \cdot \Theta^{t,(1)}\)

🟧 \(\quad m_{x}^{(0)} = \sum_{z \in \mathcal{C}(x)} m_{z \rightarrow x}^{(1\rightarrow0)}\)

🟩 \(\quad m_x = m_{x}^{(0)}\)

🟦 \(\quad h_x^{t+1, (0)} = \sigma(m_x)\)

Given a specific node \(\mathcal{v}_{i}\), a HyperGAT layer first learns the representations of all its connected hyperedges \(\mathcal{E}_{i}\). As not all the nodes in a hyperedge \(\mathcal{e}_{j} \in \mathcal{E}_{i}\) contribute equally to the hyperedge meaning, we introduce an attention mechanism (i.e., node-level attention) that highlights the nodes most important to the meaning of the hyperedge and then aggregates them to compute the hyperedge representation \(\mathcal{f}_{j}^{l}\). Formally:

\[ \mathcal{f}_{j}^{l} = \sigma (\sum_{\mathcal{v}_{k} \in \mathcal{e}_{j}} \alpha_{jk} \mathcal{W}_{1} \mathcal{h}_{k}^{l-1})\]

where \(\sigma\) is a nonlinearity such as ReLU and \(\mathcal{W}_{1}\) is a trainable weight matrix. \(\alpha_{jk}\) denotes the attention coefficient of node \(\mathcal{v}_{k}\) in the hyperedge \(\mathcal{e}_{j}\), which can be computed by:

\[\alpha_{jk} = \frac{\operatorname{exp}(a_{1}^{T}u_{k})}{\sum\limits_{\mathcal{v}_{p} \in \mathcal{e}_{j}} \operatorname{exp}(a_{1}^{T}u_{p})}\]

where \(a_{1}^{T}\) is a weight vector (a.k.a. context vector) and, following the HyperGAT paper, \(u_{k} = \operatorname{LeakyReLU}(\mathcal{W}_{1} \mathcal{h}_{k}^{l-1})\).
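The node-level attention step can be sketched on a toy hypergraph with a dense incidence matrix. This is only an illustration of the two formulas, not the `topomodelx` implementation: the tensors `W1` and `a1` are random stand-ins for the trainable parameters, and \(u_k\) is assumed to be \(\operatorname{LeakyReLU}(\mathcal{W}_1 \mathcal{h}_k^{l-1})\) as in the HyperGAT paper.

```python
import torch

torch.manual_seed(0)

n_nodes, n_hyperedges, d_in, d_out = 4, 2, 3, 5

# Dense unsigned incidence matrix: B[k, j] = 1 if node k belongs to hyperedge j.
B = torch.tensor(
    [[1.0, 0.0],
     [1.0, 1.0],
     [0.0, 1.0],
     [0.0, 1.0]]
)

h = torch.randn(n_nodes, d_in)  # node features h_k^{l-1}
W1 = torch.randn(d_in, d_out)   # hypothetical stand-in for the trainable matrix W_1
a1 = torch.randn(d_out)         # hypothetical stand-in for the context vector a_1

# Assumed definition: u_k = LeakyReLU(W_1 h_k^{l-1}).
u = torch.nn.functional.leaky_relu(h @ W1)

# One raw score per node, repeated for every hyperedge column.
scores = (u @ a1).unsqueeze(1).expand(n_nodes, n_hyperedges)

# Mask non-members with -inf so the softmax only runs over nodes in each hyperedge.
scores = scores.masked_fill(B == 0, float("-inf"))
alpha = torch.softmax(scores, dim=0)  # alpha[k, j]; each column sums to 1

# f_j = sigma(sum_k alpha_jk W_1 h_k^{l-1}), with ReLU as sigma.
f = torch.relu(alpha.T @ (h @ W1))  # shape (n_hyperedges, d_out)
```

Masking with `-inf` before the softmax is one simple way to restrict the normalization to the members of each hyperedge; a sparse implementation would typically avoid materializing the full node-by-hyperedge score matrix.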

Edge-level Attention. With all the hyperedge representations \(\{\mathcal{f}_{j}^{l} \mid \forall \mathcal{e}_{j} \in \mathcal{E}_{i}\}\), we again apply an edge-level attention mechanism to highlight the informative hyperedges for learning the next-layer representation of node \(\mathcal{v}_{i}\). This process can be formally expressed as:

\[\mathcal{h}_{i}^{l} = \sigma (\sum_{\mathcal{e}_{j} \in \mathcal{E}_{i}} \beta_{ij} \mathcal{W}_{2} \mathcal{f}_{j}^{l})\]

where \(\mathcal{h}_{i}^{l}\) is the output representation of node \(\mathcal{v}_{i}\) and \(\mathcal{W}_{2}\) is a weight matrix. \(\beta_{ij}\) denotes the attention coefficient of hyperedge \(\mathcal{e}_{j}\) on node \(\mathcal{v}_{i}\) , which can be computed by:

\[\beta_{ij} = \frac{\operatorname{exp}(a_{2}^{T}v_{j})}{\sum\limits_{\mathcal{e}_{p} \in \mathcal{E}_{i}} \operatorname{exp}(a_{2}^{T}v_{p})}\]
\[\mathcal{v}_{j} = \operatorname{LeakyReLU} ([ \mathcal{W}_{2}\mathcal{f}_{j}^{l} || \mathcal{W}_{1}\mathcal{h}_{i}^{l-1} ])\]

where \(\mathcal{a}_{2}^{T}\) is another weight (context) vector for measuring the importance of the hyperedges and || is the concatenation operation.
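The edge-level attention step can be sketched in the same dense style. Again this is a minimal illustration under stated assumptions rather than the library's code: `W1`, `W2`, and `a2` are random stand-ins for the trainable parameters, and the hyperedge representations `f` are random placeholders instead of the output of a node-level step.

```python
import torch

torch.manual_seed(0)

n_nodes, n_hyperedges, d_in, d_out = 4, 2, 3, 5

# B[i, j] = 1 if node i belongs to hyperedge j.
B = torch.tensor([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.0, 1.0]])

h = torch.randn(n_nodes, d_in)        # h_i^{l-1}
f = torch.randn(n_hyperedges, d_out)  # f_j^l (placeholder values)
W1 = torch.randn(d_in, d_out)         # hypothetical stand-in for W_1
W2 = torch.randn(d_out, d_out)        # hypothetical stand-in for W_2
a2 = torch.randn(2 * d_out)           # hypothetical stand-in for a_2

Wh = h @ W1  # (n_nodes, d_out)
Wf = f @ W2  # (n_hyperedges, d_out)

# For every (node i, hyperedge j) pair, build [W_2 f_j || W_1 h_i].
pair = torch.cat(
    [
        Wf.unsqueeze(0).expand(n_nodes, -1, -1),
        Wh.unsqueeze(1).expand(-1, n_hyperedges, -1),
    ],
    dim=-1,
)  # (n_nodes, n_hyperedges, 2 * d_out)
v = torch.nn.functional.leaky_relu(pair)

# Scores a_2^T v_j, masked so each node only attends to its own hyperedges.
scores = (v @ a2).masked_fill(B == 0, float("-inf"))
beta = torch.softmax(scores, dim=1)  # beta[i, j]; each row sums to 1

# h_i^l = sigma(sum_j beta_ij W_2 f_j^l), with ReLU as sigma.
h_next = torch.relu(beta @ Wf)  # (n_nodes, d_out)
```

Note that the softmax here normalizes over hyperedges for each node (rows), whereas node-level attention normalizes over nodes for each hyperedge (columns).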

[1]:
import numpy as np
import toponetx as tnx
import torch
from sklearn.model_selection import train_test_split

from topomodelx.nn.hypergraph.hypergat import HyperGAT
from topomodelx.utils.sparse import from_sparse
[2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda

Pre-processing#

Import data#

The first step is to import the dataset, shrec16, a benchmark dataset for 3D mesh classification. We then lift each mesh into our domain of choice, a hypergraph.

We will also retrieve:

  • the input signals defined on the nodes, edges, and faces of each complex; the node features are what we feed to the model

  • the label associated with each hypergraph

[3]:
shrec, _ = tnx.datasets.shrec_16(size="small")

x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
simplexes = shrec["complexes"]
Loading shrec 16 small dataset...

done!
[4]:
i_complex = 6
print(
    f"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}."
)
The 6th simplicial complex has 252 nodes with features of dimension 6.
The 6th simplicial complex has 750 edges with features of dimension 10.
The 6th simplicial complex has 500 faces with features of dimension 7.

Define neighborhood structures and lift into hypergraph domain.#

Now we retrieve the neighborhood structures (i.e., their representative matrices) that we will use to send messages on each simplicial complex. In the case of this architecture, we need the boundary matrix (or incidence matrix) \(B_1\) with shape \(n_\text{nodes} \times n_\text{edges}\).

We lift each simplicial complex into a hypergraph and read the incidence matrix from the result (note that all incidence matrices in the hypergraph domain must be unsigned). The pairwise edges become 2-node hyperedges, and the faces of the simplicial complex become 3-node hyperedges, so \(n_\text{edges}\) here counts both.
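To make the lifting concrete, here is a small NumPy sketch for a single filled triangle. It does not use `toponetx` and the column ordering of the library's incidence matrix may differ; it only illustrates why the lifted hypergraph has one column per edge plus one per face (consistent with 750 edges + 500 faces = 1250 hyperedges for the complex inspected in this notebook).

```python
import numpy as np

# One filled triangle on nodes 0, 1, 2.
nodes = [0, 1, 2]
edges = [(0, 1), (0, 2), (1, 2)]
faces = [(0, 1, 2)]

# After lifting, every edge and every face is a hyperedge.
hyperedges = edges + faces

# Unsigned incidence matrix: rows are nodes, columns are hyperedges.
incidence = np.zeros((len(nodes), len(hyperedges)), dtype=int)
for j, hyperedge in enumerate(hyperedges):
    for node in hyperedge:
        incidence[node, j] = 1

print(incidence.shape)  # (3, 4)
print(incidence)
```

Each column sum gives the size of the corresponding hyperedge: 2 for the lifted edges and 3 for the lifted face.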

[5]:
hg_list = []
incidence_1_list = []
for simplex in simplexes:
    # Lift the simplicial complex: edges and faces both become hyperedges
    hg = simplex.to_hypergraph()
    hg_list.append(hg)

# Extract the incidence matrix of each hypergraph and convert it to a torch tensor
for hg in hg_list:
    incidence_1 = hg.incidence_matrix()
    incidence_1 = from_sparse(incidence_1)
    incidence_1_list.append(incidence_1)
[6]:
i_complex = 6
print(
    f"The {i_complex}th hypergraph has an incidence matrix of shape {incidence_1_list[i_complex].shape}."
)
The 6th hypergraph has an incidence matrix of shape torch.Size([252, 1250]).

Train the Neural Network#

Define the network that initializes the base model and sets up the readout operation. Different downstream tasks might require different pooling procedures.
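As a rough illustration of what "different pooling procedures" can mean for a graph-level task, the sketch below reduces a toy node-feature matrix to a single vector in three common ways. The `Network` class in this notebook uses the max variant; the mean and sum variants are shown only for comparison and are not part of the original code.

```python
import torch

# Toy hidden representation: 3 nodes, 2 hidden channels.
x_0 = torch.tensor([[1.0, 4.0], [3.0, 2.0], [5.0, 0.0]])

max_pool = torch.max(x_0, dim=0)[0]  # channel-wise maximum over nodes
mean_pool = x_0.mean(dim=0)          # channel-wise average over nodes
sum_pool = x_0.sum(dim=0)            # channel-wise sum over nodes

print(max_pool)   # tensor([5., 4.])
print(mean_pool)  # tensor([3., 2.])
print(sum_pool)   # tensor([9., 6.])
```

All three are invariant to the ordering of the nodes, which is the property a hypergraph-level readout needs; for a node-level task no pooling is applied and the linear layer acts on each node separately.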

[7]:
class Network(torch.nn.Module):
    """Network class that initializes the HyperGAT model and readout layer.

    Base model parameters:
    ----------
    Required:
    in_channels : int
        Dimension of the input features.
    hidden_channels : int
        Dimension of the hidden features.

    Optional:
    **kwargs : dict
        Additional arguments for the base model.

    Readout layer parameters:
    ----------
    out_channels : int
        Dimension of the output features.
    task_level : str
        Level of the task. Either "graph" or "node".
    """

    def __init__(
        self, in_channels, hidden_channels, out_channels, task_level="graph", **kwargs
    ):
        super().__init__()

        # Define the model
        self.base_model = HyperGAT(
            in_channels=in_channels, hidden_channels=hidden_channels, **kwargs
        )

        # Readout
        self.linear = torch.nn.Linear(hidden_channels, out_channels)
        self.out_pool = task_level == "graph"

    def forward(self, x_0, incidence_1):
        # Base model
        x_0, x_1 = self.base_model(x_0, incidence_1)

        # Pool over all nodes in the hypergraph
        x = torch.max(x_0, dim=0)[0] if self.out_pool is True else x_0

        return self.linear(x)

Initialize the model

[8]:
# Base model hyperparameters
in_channels = x_0s[0].shape[1]
hidden_channels = 32
out_dim = 1
n_layers = 3

# Readout hyperparameters
out_channels = 1
task_level = "graph"


model = Network(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    n_layers=n_layers,
    task_level=task_level,
).to(device)
[9]:
# in_channels = x_0s[0].shape[1]
# hidden_channels = 32
# out_dim = 1
# n_layers = 3

# # Define the model
# model = HyperGAT(
#     in_channels=in_channels,
#     hidden_channels=hidden_channels,
#     out_channels=out_dim,
#     n_layers=n_layers
#     )
# model = model.to(device)

# Optimizer and loss
opt = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()
[10]:
test_size = 0.2
x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)
incidence_1_train, incidence_1_test = train_test_split(
    incidence_1_list, test_size=test_size, shuffle=False
)
y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)

The following cell performs the training, looping over the training set for a small number of epochs.

Note: The number of epochs below has been kept low to facilitate debugging and testing. Real use cases will likely require more epochs.

[11]:
test_interval = 5
num_epochs = 5
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_0, incidence_1, y in zip(x_0_train, incidence_1_train, y_train, strict=True):
        x_0 = torch.tensor(x_0)
        x_0, incidence_1, y = (
            x_0.float().to(device),
            incidence_1.float().to(device),
            torch.tensor(y, dtype=torch.float).to(device),
        )
        opt.zero_grad()
        # Extract edge_index from sparse incidence matrix
        # edge_index, _ = to_edge_index(incidence_1)
        y_hat = model(x_0, incidence_1)
        loss = loss_fn(y_hat, y)

        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())

    if epoch_i % test_interval == 0:
        with torch.no_grad():
            train_loss = np.mean(epoch_loss)

            test_epoch_loss = []
            for x_0, incidence_1, y in zip(
                x_0_test, incidence_1_test, y_test, strict=True
            ):
                x_0 = torch.tensor(x_0)
                x_0, incidence_1, y = (
                    x_0.float().to(device),
                    incidence_1.float().to(device),
                    torch.tensor(y, dtype=torch.float).to(device),
                )
                y_hat = model(x_0, incidence_1)
                loss = loss_fn(y_hat, y)
                test_epoch_loss.append(loss.item())

            print(
                f"Epoch: {epoch_i} train_loss {train_loss:.4f} test_loss: {np.mean(test_epoch_loss):.4f}",
                flush=True,
            )
/usr/local/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
Epoch: 5 train_loss 7195.0287 test_loss: 19283.1447
Epoch: 10 train_loss 1705.4479 test_loss: 4695.8421
Epoch: 15 train_loss 2624.2060 test_loss: 3844.4079
Epoch: 20 train_loss 6754.1770 test_loss: 4970.3517
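The `UserWarning` in the output above is raised because the prediction has shape `[1]` while the target is a 0-dimensional tensor. With a single element the broadcast result is the same number, but aligning the shapes removes the warning and avoids surprises if the loop is later batched. A minimal sketch of two possible fixes, using made-up values for `y_hat` and `y`:

```python
import torch

loss_fn = torch.nn.MSELoss()

y_hat = torch.tensor([0.7])  # model output for one hypergraph, shape (1,)
y = torch.tensor(3.0)        # label as built in the training loop, shape ()

# Option 1: give the target the same shape as the prediction.
loss_a = loss_fn(y_hat, y.reshape(1))

# Option 2: drop the trailing dimension of the prediction.
loss_b = loss_fn(y_hat.squeeze(-1), y)

print(loss_a.item(), loss_b.item())
```

Either option could be applied inside the training and evaluation loops; which one is preferable depends on whether downstream code expects a scalar or a length-1 prediction.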