Train a Hypergraph Message Passing Neural Network (HMPNN)#

In this notebook, we will create and train a Hypergraph Message Passing Neural Network (HMPNN) in the hypergraph domain. This method was introduced in the paper Message Passing Neural Networks for Hypergraphs by Heydari and Livi (2022). We will use the benchmark dataset Cora, a collection of 2708 academic papers connected by 5429 citation relations, for the task of node classification. There are 7 category labels: Case_Based, Genetic_Algorithms, Neural_Networks, Probabilistic_Methods, Reinforcement_Learning, Rule_Learning and Theory.

Each document is initially represented as a binary bag-of-words vector of length 1433: the dictionary consists of 1433 unique words extracted from the papers, and a value of 1 indicates that the corresponding word occurs in the paper.
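As a minimal illustration of this encoding (using a hypothetical three-word vocabulary rather than the actual 1433-word Cora dictionary):

import torch

# Hypothetical toy vocabulary; Cora's dictionary has 1433 words.
vocabulary = ["network", "genetic", "inference"]
paper_words = {"network", "inference"}  # words occurring in one toy paper

# Binary bag-of-words vector: 1 if the word occurs in the paper, 0 otherwise.
x = torch.tensor([1.0 if word in paper_words else 0.0 for word in vocabulary])
print(x)  # tensor([1., 0., 1.])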

[1]:
import torch
import torch_geometric.datasets as geom_datasets
from sklearn.metrics import accuracy_score

from topomodelx.nn.hypergraph.hmpnn import HMPNN

torch.manual_seed(0)

If GPUs are available, we will make use of them. Otherwise, this will run on the CPU.

[2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda

Pre-processing#

Here we download the dataset. It contains the initial node representations, the adjacency information, the category labels, and the train-val-test masks.

[3]:
dataset = geom_datasets.Planetoid(root="tmp/", name="cora")[0]

Below, we construct the incidence matrix (\(B_1\)), which is of shape \(n_\text{nodes} \times n_\text{edges}\). For this citation dataset, each hyperedge groups a paper's neighbors in the citation graph, so \(n_\text{edges} = n_\text{nodes}\) and the matrix is square.

[4]:
dataset["incidence_1"] = torch.sparse_coo_tensor(
    dataset["edge_index"], torch.ones(dataset["edge_index"].shape[1]), dtype=torch.long
)
dataset = dataset.to(device)
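
As a quick sanity check, we can inspect the resulting sparse tensor. The expected values in the comments are a sketch assuming the standard Planetoid version of Cora, in which each of the 5429 citations is stored in both directions:

print(dataset["incidence_1"].shape)   # expected: torch.Size([2708, 2708])
print(dataset["incidence_1"]._nnz())  # expected: 10556 (each citation stored twice)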
[5]:
x_0s = dataset["x"]
y = dataset["y"]
incidence_1 = dataset["incidence_1"]

Train the Neural Network#

We then specify the hyperparameters and construct the model, the loss and optimizer.

[6]:
class Network(torch.nn.Module):
    """Network class that initializes the base model and readout layer.

    Base model parameters:
    ----------
    Required:
    in_channels : int
        Dimension of the input features.
    hidden_channels : int
        Dimension of the hidden features.

    Optional:
    **kwargs : dict
        Additional arguments for the base model.

    Readout layer parameters:
    ----------
    out_channels : int
        Dimension of the output features.
    task_level : str
        Level of the task. Either "graph" or "node".
    """

    def __init__(
        self, in_channels, hidden_channels, out_channels, task_level="graph", **kwargs
    ):
        super().__init__()

        # Define the model
        self.base_model = HMPNN(
            in_channels=in_channels, hidden_channels=hidden_channels, **kwargs
        )

        # Readout
        self.linear = torch.nn.Linear(hidden_channels, out_channels)
        self.out_pool = task_level == "graph"

    def forward(self, x_0, x_1, incidence_1):
        # Base model
        x_0, x_1 = self.base_model(x_0, x_1, incidence_1)

        # Pool over all nodes in the hypergraph
        x = torch.max(x_0, dim=0)[0] if self.out_pool else x_0

        return self.linear(x)

Initialize the model. Since Cora has 7 classes, out_channels is 7, the task is node-level, and the pooling step in the readout is skipped.

[7]:
# Base model hyperparameters
in_channels = x_0s.shape[1]
hidden_channels = 128
n_layers = 1

# Readout hyperparameters
out_channels = torch.unique(y).shape[0]
task_level = "graph" if out_channels == 1 else "node"


model = Network(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    n_layers=n_layers,
    task_level=task_level,
).to(device)
[8]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()


train_mask = dataset["train_mask"]
val_mask = dataset["val_mask"]
test_mask = dataset["test_mask"]

Now it’s time to train the model, looping over the network for a small number of epochs. We keep training minimal for the purpose of rapid testing.

Note: The number of epochs below has been kept low to facilitate debugging and testing. Real use cases will likely require more epochs; the sample output shown further down comes from a longer run of 100 epochs with the same settings.

[9]:
torch.manual_seed(0)
test_interval = 5
num_epochs = 5


# Hyperedge features are initialized to zeros; since the incidence matrix here is
# square, there is one hyperedge per node and zeros_like(x_0s) has the right shape.
initial_x_1 = torch.zeros_like(x_0s)
for epoch in range(1, num_epochs + 1):
    model.train()
    optimizer.zero_grad()
    y_hat = model(x_0s, initial_x_1, incidence_1)
    loss = loss_fn(y_hat[train_mask], y[train_mask])
    loss.backward()
    optimizer.step()

    train_loss = loss.item()
    y_pred = y_hat.argmax(dim=-1)
    train_acc = accuracy_score(y[train_mask].cpu(), y_pred[train_mask].cpu())

    if epoch % test_interval == 0:
        model.eval()

        with torch.no_grad():
            y_hat = model(x_0s, initial_x_1, incidence_1)
        y_pred = y_hat.argmax(dim=-1)

        val_loss = loss_fn(y_hat[val_mask], y[val_mask]).item()
        val_acc = accuracy_score(y[val_mask].cpu(), y_pred[val_mask].cpu())
        test_acc = accuracy_score(y[test_mask].cpu(), y_pred[test_mask].cpu())

        print(
            f"Epoch: {epoch + 1} train loss: {train_loss:.4f} train acc: {train_acc:.2f} "
            f"val loss: {val_loss:.4f} val acc: {val_acc:.2f} test acc: {test_acc:.2f}"
        )
Epoch: 6 train loss: 1.2665 train acc: 0.82 val loss: 1.9848 val acc: 0.23 test acc: 0.23
Epoch: 11 train loss: 0.8439 train acc: 0.99 val loss: 1.7569 val acc: 0.38 test acc: 0.40
Epoch: 16 train loss: 0.5119 train acc: 1.00 val loss: 1.6846 val acc: 0.39 test acc: 0.42
Epoch: 21 train loss: 0.2717 train acc: 1.00 val loss: 1.5872 val acc: 0.43 test acc: 0.45
Epoch: 26 train loss: 0.1571 train acc: 1.00 val loss: 1.6143 val acc: 0.41 test acc: 0.42
Epoch: 31 train loss: 0.0816 train acc: 1.00 val loss: 1.5894 val acc: 0.45 test acc: 0.45
Epoch: 36 train loss: 0.0478 train acc: 1.00 val loss: 1.6020 val acc: 0.46 test acc: 0.46
Epoch: 41 train loss: 0.0298 train acc: 1.00 val loss: 1.6153 val acc: 0.47 test acc: 0.47
Epoch: 46 train loss: 0.0214 train acc: 1.00 val loss: 1.6499 val acc: 0.47 test acc: 0.47
Epoch: 51 train loss: 0.0160 train acc: 1.00 val loss: 1.6764 val acc: 0.48 test acc: 0.48
Epoch: 56 train loss: 0.0149 train acc: 1.00 val loss: 1.6986 val acc: 0.48 test acc: 0.49
Epoch: 61 train loss: 0.0123 train acc: 1.00 val loss: 1.6888 val acc: 0.47 test acc: 0.49
Epoch: 66 train loss: 0.0097 train acc: 1.00 val loss: 1.6670 val acc: 0.48 test acc: 0.50
Epoch: 71 train loss: 0.0078 train acc: 1.00 val loss: 1.6547 val acc: 0.49 test acc: 0.50
Epoch: 76 train loss: 0.0072 train acc: 1.00 val loss: 1.6484 val acc: 0.49 test acc: 0.50
Epoch: 81 train loss: 0.0066 train acc: 1.00 val loss: 1.6378 val acc: 0.49 test acc: 0.51
Epoch: 86 train loss: 0.0064 train acc: 1.00 val loss: 1.6507 val acc: 0.49 test acc: 0.51
Epoch: 91 train loss: 0.0060 train acc: 1.00 val loss: 1.6745 val acc: 0.50 test acc: 0.51
Epoch: 96 train loss: 0.0051 train acc: 1.00 val loss: 1.6682 val acc: 0.50 test acc: 0.52
Epoch: 101 train loss: 0.0047 train acc: 1.00 val loss: 1.6412 val acc: 0.50 test acc: 0.52
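
After training, a final evaluation pass can be run without tracking gradients. This is a minimal sketch reusing the objects defined above:

model.eval()
with torch.no_grad():
    y_hat = model(x_0s, initial_x_1, incidence_1)
    y_pred = y_hat.argmax(dim=-1)

final_test_acc = accuracy_score(y[test_mask].cpu(), y_pred[test_mask].cpu())
print(f"Final test accuracy: {final_test_acc:.2f}")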
[ ]: