Train a HyperSAGE TNN#

In this notebook, we will create and train a HyperSAGE model (Arya et al., 2020), a two-level message passing strategy for learning on hypergraphs. We will use a benchmark dataset, Cora, a citation network, to train the model to perform classification at the level of the nodes.

Following the notation of the “awesome-tnns” GitHub repo, the HyperSAGE layer is described by the following message passing scheme:

🟥 \(\quad m_{y \rightarrow z}^{(0 \rightarrow 1)} = (B_1)^T_{zy} \cdot w_y \cdot (h_y^{(0)})^p\)

🟥 \(\quad m_z^{(0 \rightarrow 1)} = \left(\frac{1}{\vert \mathcal{B}(z)\vert}\sum_{y \in \mathcal{B}(z)} m_{y \rightarrow z}^{(0 \rightarrow 1)}\right)^{\frac{1}{p}}\)

🟥 \(\quad m_{z \rightarrow x}^{(1 \rightarrow 0)} = (B_1)_{xz} \cdot w_z \cdot (m_z^{(0 \rightarrow 1)})^p\)

🟧 \(\quad m_x^{(1 \rightarrow 0)} = \left(\frac{1}{\vert \mathcal{C}(x) \vert}\sum_{z \in \mathcal{C}(x)} m_{z \rightarrow x}^{(1 \rightarrow 0)}\right)^{\frac{1}{p}}\)

🟩 \(\quad m_x^{(0)} = m_x^{(1 \rightarrow 0)}\)

🟦 \(\quad h_x^{t+1, (0)} = \sigma \left(\frac{m_x^{(0)} + h_x^{t,(0)}}{\lvert m_x^{(0)} + h_x^{t,(0)}\rvert} \cdot \Theta^t\right)\)
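
To make the two aggregation levels concrete, here is a minimal sketch in plain Python. It covers only the message steps (🟥 and 🟧), not the final update with \(\Theta^t\). It assumes unit weights (\(w_y = w_z = 1\)), non-negative scalar features, and hyperedges given as lists of node indices. The helper names `power_mean` and `two_level_messages` are illustrative and not part of TopoModelX.

```python
def power_mean(values, p):
    """Generalized (power) mean: (mean(v ** p)) ** (1 / p), for non-negative v."""
    return (sum(v**p for v in values) / len(values)) ** (1.0 / p)


def two_level_messages(h, hyperedges, p=2.0):
    """Aggregate node features to hyperedges, then back to nodes.

    h          : list of non-negative scalar node features
    hyperedges : list of hyperedges, each a list of node indices
    Assumes unit weights and that every node lies in at least one hyperedge.
    """
    # Level 1 (nodes -> hyperedges): power mean of the features of the nodes in z.
    m_edge = [power_mean([h[y] for y in edge], p) for edge in hyperedges]

    # Level 2 (hyperedges -> nodes): power mean over the hyperedges containing x.
    m_node = []
    for x in range(len(h)):
        incident = [m_edge[z] for z, edge in enumerate(hyperedges) if x in edge]
        m_node.append(power_mean(incident, p))
    return m_node


h = [1.0, 3.0, 5.0]
hyperedges = [[0, 1], [1, 2]]
messages = two_level_messages(h, hyperedges, p=1.0)
print(messages)
```

With \(p = 1\) the power mean reduces to the arithmetic mean, so node 1, which belongs to both hyperedges, receives the average of the two hyperedge messages.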

Additional theoretical clarifications#

Arya et al. propose to interpret the propagation of information in a given hypergraph as a two-level aggregation problem, where the neighborhood of any node is divided into intra-edge neighbors and inter-edge neighbors. Given a hypergraph \(H=(\mathcal{V}, \mathcal{E})\), let \(\textbf{X}\) denote the feature matrix, such that \(\textbf{x}_{i} \in \textbf{X}\) is the feature set for node \(\mathcal{v}_{i} \in \mathcal{V}\). For two-level aggregation, let \(\mathcal{F}_{1}(\cdot)\) and \(\mathcal{F}_{2}(\cdot)\) denote the intra-edge and inter-edge aggregation functions, respectively. Message passing at node \(\mathcal{v}_{i}\) for aggregation of information at the \(l^{th}\) layer can then be stated as

\(\quad \mathcal{x}_{i,l}^{(e)} \leftarrow \mathcal{F}_{1}\left(\{ \mathcal{x}_{j,l-1} \mid \mathcal{v}_{j} \in \mathcal{N}(\mathcal{v}_{i}, \textbf{e}, \alpha) \}\right),\)

\(\quad \mathcal{x}_{i,l} \leftarrow \mathcal{x}_{i,l-1} + \mathcal{F}_{2}\left(\{ \mathcal{x}_{i,l}^{(e)} \mid \mathcal{v}_{i} \in E(\mathcal{v}_{i}) \}\right),\)

where \(\mathcal{x}_{i,l}^{(e)}\) refers to the aggregated feature set at \(\mathcal{v}_{i}\) obtained with intra-edge aggregation for edge \(\textbf{e}\).
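
The two-level formulation can be sketched with interchangeable aggregation functions. The example below is an illustration under simplifying assumptions: features are scalars, the intra-edge neighborhood is taken to be every node of the hyperedge (no sampling), and the inter-edge step runs over all hyperedges that contain the node. The name `aggregate_node` and the arguments `f1` and `f2` are hypothetical.

```python
from statistics import mean


def aggregate_node(i, x_prev, hyperedges, f1=mean, f2=mean):
    """One two-level update for node i.

    x_prev     : list of scalar features from the previous layer
    hyperedges : list of hyperedges, each a list of node indices
    f1, f2     : intra-edge and inter-edge aggregation functions
    """
    # Intra-edge step: one aggregated value per hyperedge containing node i.
    # Assumption: the neighborhood is every node of the hyperedge (no sampling).
    per_edge = [f1([x_prev[j] for j in edge]) for edge in hyperedges if i in edge]

    # Inter-edge step: combine the per-edge values and add the residual term.
    return x_prev[i] + f2(per_edge)


x_prev = [1.0, 3.0, 5.0]
hyperedges = [[0, 1], [1, 2]]
x_mean = aggregate_node(1, x_prev, hyperedges)
x_max = aggregate_node(1, x_prev, hyperedges, f1=max, f2=max)
print(x_mean, x_max)
```

Swapping `mean` for `max` changes only the two aggregation functions; the residual structure of the update stays the same.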

[1]:
"""
Train a HyperSAGE model on the Cora dataset.

HyperSAGE is a hypergraph neural network architecture based on
two-level message passing (Arya et al., 2020).
"""

import numpy as np
import torch
import torch_geometric.datasets as geom_datasets
from torch_geometric.utils import to_undirected

from topomodelx.nn.hypergraph.hypersage import HyperSAGE

torch.manual_seed(0)

If GPUs are available, we will make use of them. Otherwise, this will run on CPU.

[2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda

Pre-processing#

The first step is to import the dataset, Cora, a benchmark classification dataset. We then lift the graph into our domain of choice, a hypergraph.

[3]:
cora = geom_datasets.Planetoid(root="tmp/", name="cora")
data = cora[0]  # The dataset holds a single graph.

x_0s = data.x
y = data.y
edge_index = data.edge_index

train_mask = data.train_mask
val_mask = data.val_mask
test_mask = data.test_mask

Now we retrieve the neighborhood structure (i.e. its representative matrix) that we will use to send messages from nodes to hyperedges. In the case of this architecture, we need the boundary matrix (or incidence matrix) \(B_1\) with shape \(n_\text{nodes} \times n_\text{edges}\).

For the Cora citation dataset, we lift the graph structure to the hypergraph domain by creating one hyperedge from the 1-hop graph neighborhood of each node.

[4]:
# Ensure the graph is undirected (optional but often useful for one-hop neighborhoods).
edge_index = to_undirected(edge_index)

# Create a list of one-hop neighborhoods for each node.
one_hop_neighborhoods = []
for node in range(data.num_nodes):
    # Get the one-hop neighbors of the current node from the undirected edge list.
    neighbors = edge_index[1, edge_index[0] == node]

    # Append the neighbors to the list of one-hop neighborhoods.
    one_hop_neighborhoods.append(neighbors.numpy())

# Detect and eliminate duplicate hyperedges.
unique_hyperedges = set()
hyperedges = []
for neighborhood in one_hop_neighborhoods:
    # Sort the neighborhood to ensure consistent comparison.
    neighborhood = tuple(sorted(neighborhood))
    if neighborhood not in unique_hyperedges:
        hyperedges.append(list(neighborhood))
        unique_hyperedges.add(neighborhood)
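
As a quick sanity check of the lifting procedure, the following sketch applies the same idea to a toy graph in plain Python. The graph and the helper name `lift_one_hop` are made up for illustration. As in the cell above, each hyperedge contains the neighbors of a node but not the node itself.

```python
def lift_one_hop(num_nodes, edges):
    """Build deduplicated hyperedges from the one-hop neighborhood of each node.

    edges : list of undirected edges (u, v)
    """
    neighbors = {node: set() for node in range(num_nodes)}
    for u, v in edges:
        neighbors[u].add(v)
        neighbors[v].add(u)

    seen = set()
    hyperedges = []
    for node in range(num_nodes):
        hyperedge = tuple(sorted(neighbors[node]))
        # Isolated nodes have no neighbors and yield no hyperedge in this sketch.
        if hyperedge and hyperedge not in seen:
            seen.add(hyperedge)
            hyperedges.append(list(hyperedge))
    return hyperedges


# Star graph: node 1 is connected to nodes 0, 2 and 3.
toy_hyperedges = lift_one_hop(4, [(0, 1), (1, 2), (1, 3)])
print(toy_hyperedges)
```

Nodes 0, 2 and 3 all have the same neighborhood, namely node 1 alone, so they collapse into a single hyperedge of size one. This is the same mechanism that produces single-node hyperedges in the lifted dataset.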

Additionally, we print the statistics associated with the resulting hyperedges.

[5]:
# Calculate hyperedge statistics.
hyperedge_sizes = [len(he) for he in hyperedges]
min_size = min(hyperedge_sizes)
max_size = max(hyperedge_sizes)
mean_size = np.mean(hyperedge_sizes)
median_size = np.median(hyperedge_sizes)
std_size = np.std(hyperedge_sizes)
num_single_node_hyperedges = sum(np.array(hyperedge_sizes) == 1)

# Print the hyperedge statistics.
print("Hyperedge statistics: ")
print("Number of hyperedges without duplicated hyperedges", len(hyperedges))
print(f"min = {min_size}, ")
print(f"max = {max_size}, ")
print(f"mean = {mean_size}, ")
print(f"median = {median_size}, ")
print(f"std = {std_size}, ")
print(f"Number of hyperedges with size equal to one = {num_single_node_hyperedges}")
Hyperedge statistics:
Number of hyperedges without duplicated hyperedges 2581
min = 1,
max = 168,
mean = 4.003099573808601,
median = 3.0,
std = 5.327622607829558,
Number of hyperedges with size equal to one = 412
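
Summary statistics hide the shape of the size distribution, so a histogram of hyperedge sizes can be a useful complement. The sketch below uses `collections.Counter` on a small, made-up list of hyperedges; the variable names are illustrative.

```python
from collections import Counter

# Hypothetical hyperedges, each a list of node indices.
toy_hyperedges = [[1], [0, 2, 3], [4], [2, 5]]

# Map each hyperedge size to the number of hyperedges with that size.
size_counts = Counter(len(he) for he in toy_hyperedges)
for size in sorted(size_counts):
    print(f"size {size}: {size_counts[size]} hyperedge(s)")
```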

Next, we construct the incidence matrix.

[6]:
max_edges = len(hyperedges)
incidence_1 = np.zeros((x_0s.shape[0], max_edges))
for col, neighborhood in enumerate(hyperedges):
    for row in neighborhood:
        incidence_1[row, col] = 1

assert (incidence_1.sum(0) > 0).all(), "Some hyperedges are empty"
assert (incidence_1.sum(1) > 0).all(), "Some nodes are not in any hyperedges"
incidence_1 = torch.Tensor(incidence_1).to_sparse_coo()
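
The construction can be checked by hand on a toy example. The sketch below builds a dense incidence matrix from nested lists in plain Python; `build_incidence` is a hypothetical helper and the hyperedges are made up.

```python
def build_incidence(num_nodes, hyperedges):
    """Return a dense num_nodes x num_hyperedges 0/1 incidence matrix."""
    incidence = [[0] * len(hyperedges) for _ in range(num_nodes)]
    for col, hyperedge in enumerate(hyperedges):
        for row in hyperedge:
            incidence[row][col] = 1
    return incidence


toy_incidence = build_incidence(4, [[1], [0, 2, 3]])
for row in toy_incidence:
    print(row)

# Row sums count the hyperedges per node; column sums give hyperedge sizes.
node_degrees = [sum(row) for row in toy_incidence]
hyperedge_sizes = [sum(col) for col in zip(*toy_incidence)]
print(node_degrees, hyperedge_sizes)
```

Every row sum and every column sum is positive here, which is the same condition the two assertions in the cell above verify for the full dataset.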

Create the Neural Network#

Define the network that initializes the base model and sets up the readout operation. Different downstream tasks might require different pooling procedures.

[7]:
class Network(torch.nn.Module):
    """Network class that initializes the base model and readout layer.

    Base model parameters:
    ----------
    Required:
    in_channels : int
        Dimension of the input features.
    hidden_channels : int
        Dimension of the hidden features.

    Optional:
    **kwargs : dict
        Additional arguments for the base model.

    Readout layer parameters:
    ----------
    out_channels : int
        Dimension of the output features.
    task_level : str
        Level of the task. Either "graph" or "node".
    """

    def __init__(
        self, in_channels, hidden_channels, out_channels, task_level="graph", **kwargs
    ):
        super().__init__()

        # Define the model
        self.base_model = HyperSAGE(
            in_channels=in_channels, hidden_channels=hidden_channels, **kwargs
        )

        # Readout
        self.linear = torch.nn.Linear(hidden_channels, out_channels)
        self.out_pool = task_level == "graph"

    def forward(self, x_0, incidence_1):
        # Base model
        x_0 = self.base_model(x_0, incidence_1)

        # Pool over all nodes in the hypergraph
        x = torch.max(x_0, dim=0)[0] if self.out_pool is True else x_0

        return self.linear(x)
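
The effect of the `task_level` switch can be illustrated without tensors. The sketch below is a plain-Python analogue of the pooling step in `forward`; the function name `readout` and the feature values are made up for illustration.

```python
def readout(node_features, task_level="graph"):
    """Pool node features for graph-level tasks; pass them through otherwise.

    node_features : list of per-node feature vectors (lists of floats)
    """
    if task_level == "graph":
        # Feature-wise maximum over all nodes, like torch.max(x_0, dim=0)[0].
        return [max(column) for column in zip(*node_features)]
    return node_features


features = [[0.1, 2.0], [0.7, -1.0], [0.3, 0.5]]
pooled = readout(features, "graph")
per_node = readout(features, "node")
print(pooled)
print(len(per_node))
```

For a graph-level task the readout yields one vector for the whole hypergraph, while for a node-level task it keeps one vector per node, so the linear layer then produces one prediction per node.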

Initialize the model

[8]:
# Base model hyperparameters
in_channels = x_0s.shape[1]
hidden_channels = 128
n_layers = 1
mlp_num_layers = 1

# Readout hyperparameters
out_channels = torch.unique(y).shape[0]
task_level = "graph" if out_channels == 1 else "node"


model = Network(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    n_layers=n_layers,
    device=device,
    task_level=task_level,
).to(device)

Train the Neural Network#

We specify the model, the loss, and an optimizer.

[9]:
# Optimizer and loss
opt = torch.optim.Adam(model.parameters(), lr=0.01)

# Categorical cross-entropy loss
loss_fn = torch.nn.CrossEntropyLoss()


# Accuracy
def acc_fn(y, y_hat):
    return (y == y_hat).float().mean()
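
For intuition about the magnitude of the loss, here is a minimal plain-Python sketch of cross-entropy on a single sample, together with an accuracy helper. The function names are illustrative. Cora has seven classes, so a model that predicts all classes with equal probability has a loss of \(\log 7\).

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax probability of the target class for one sample."""
    shift = max(logits)  # Subtract the maximum for numerical stability.
    log_norm = shift + math.log(sum(math.exp(v - shift) for v in logits))
    return log_norm - logits[target]


def accuracy(y_true, y_pred):
    """Fraction of positions where prediction and label agree."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Uniform logits over 7 classes give a loss of log(7).
uniform_loss = cross_entropy([0.0] * 7, target=3)
print(round(uniform_loss, 4))
print(accuracy([0, 1, 2, 2], [0, 1, 1, 2]))
```

A loss close to \(\log 7 \approx 1.9459\) therefore indicates predictions that are still nearly uniform over the classes.
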
[10]:
# Cast the features, incidence matrix, and labels, then move them to the device.
x_0s = x_0s.float().to(device)
incidence_1 = incidence_1.float().to(device)
y = y.long().to(device)

The training loop below runs the network for a small number of epochs. We keep training minimal for the purpose of rapid testing.

Note: The number of epochs below has been kept low to facilitate debugging and testing. Real use cases will likely require more epochs.

[12]:
torch.manual_seed(0)
test_interval = 5
num_epochs = 30

epoch_loss = []
for epoch_i in range(1, num_epochs + 1):
    model.train()

    opt.zero_grad()

    # Forward pass on all nodes; the loss uses only the training nodes.
    y_hat = model(x_0s, incidence_1)
    loss = loss_fn(y_hat[train_mask], y[train_mask])

    loss.backward()
    opt.step()
    epoch_loss.append(loss.item())

    if epoch_i % test_interval == 0:
        model.eval()
        y_hat = model(x_0s, incidence_1)

        loss = loss_fn(y_hat[train_mask], y[train_mask])
        print(f"Epoch: {epoch_i} ")
        print(
            f"Train_loss: {np.mean(epoch_loss):.4f}, acc: {acc_fn(y_hat[train_mask].argmax(1), y[train_mask]):.4f}",
            flush=True,
        )

        loss = loss_fn(y_hat[val_mask], y[val_mask])

        print(
            f"Val_loss: {loss:.4f}, Val_acc: {acc_fn(y_hat[val_mask].argmax(1), y[val_mask]):.4f}",
            flush=True,
        )

        loss = loss_fn(y_hat[test_mask], y[test_mask])
        print(
            f"Test_loss: {loss:.4f}, Test_acc: {acc_fn(y_hat[test_mask].argmax(1), y[test_mask]):.4f}",
            flush=True,
        )
Epoch: 5
Train_loss: 1.9424, acc: 0.5929
Val_loss: 1.9401, Val_acc: 0.2460
Test_loss: 1.9405, Test_acc: 0.2620
Epoch: 10
Train_loss: 1.9305, acc: 0.9357
Val_loss: 1.9221, Val_acc: 0.5680
Test_loss: 1.9220, Test_acc: 0.5580
Epoch: 15
Train_loss: 1.9105, acc: 0.9714
Val_loss: 1.8899, Val_acc: 0.6560
Test_loss: 1.8904, Test_acc: 0.6490
Epoch: 20
Train_loss: 1.8811, acc: 0.9786
Val_loss: 1.8450, Val_acc: 0.7260
Test_loss: 1.8466, Test_acc: 0.6990
Epoch: 25
Train_loss: 1.8412, acc: 0.9929
Val_loss: 1.7831, Val_acc: 0.7360
Test_loss: 1.7857, Test_acc: 0.7210
Epoch: 30
Train_loss: 1.7905, acc: 1.0000
Val_loss: 1.7067, Val_acc: 0.7420
Test_loss: 1.7103, Test_acc: 0.7220
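
One detail is worth keeping in mind when reading the log above: `Train_loss` is `np.mean(epoch_loss)`, the average of the training loss over all epochs so far, whereas the validation and test losses are computed from the current model only. The sketch below, with made-up loss values, shows how such a running mean lags behind the latest loss.

```python
from statistics import mean

# Hypothetical per-epoch training losses (illustrative values only).
epoch_loss = [1.95, 1.90, 1.80, 1.60, 1.30]

for epoch_i in range(1, len(epoch_loss) + 1):
    running_mean = mean(epoch_loss[:epoch_i])
    latest = epoch_loss[epoch_i - 1]
    print(f"Epoch {epoch_i}: running mean = {running_mean:.4f}, latest = {latest:.4f}")
```

When the loss decreases steadily, the running mean stays above the most recent value, so the reported training loss can look higher than the validation loss even when training accuracy is high.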