{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Train a Hypergraph Message Passing Neural Network (HMPNN)\n", "\n", "In this notebook, we create and train a Hypergraph Message Passing Neural Network (HMPNN), a model that operates on the hypergraph domain. The method was introduced in the paper [Message Passing Neural Networks for\n", "Hypergraphs](https://arxiv.org/abs/2203.16995) by Heydari and Livi (2022). We use the Cora benchmark dataset, a collection of 2708 academic papers and 5429 citation relations, for the task of node classification. There are 7 category labels: `Case_Based`, `Genetic_Algorithms`, `Neural_Networks`, `Probabilistic_Methods`, `Reinforcement_Learning`, `Rule_Learning` and `Theory`.\n", "\n", "Each document is initially represented as a binary vector of length 1433, where each entry corresponds to one word of a fixed vocabulary and a value of 1 means that the word appears in the paper." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:14:51.222779223Z", "start_time": "2023-06-01T16:14:49.575421023Z" } }, "outputs": [], "source": [ "import torch\n", "import torch_geometric.datasets as geom_datasets\n", "from sklearn.metrics import accuracy_score\n", "\n", "from topomodelx.nn.hypergraph.hmpnn import HMPNN\n", "\n", "torch.manual_seed(0)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "If a GPU is available, we use it; otherwise, the notebook runs on the CPU."
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:14:51.959770754Z", "start_time": "2023-06-01T16:14:51.956096841Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Pre-processing\n", "\n", "Here we download the dataset. It contains the initial node representations, the adjacency information, the category labels and the train/validation/test masks." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "dataset = geom_datasets.Planetoid(root=\"tmp/\", name=\"cora\")[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below, we construct the incidence matrix ($B_1$), which has shape $n_\\text{nodes} \\times n_\\text{edges}$." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "dataset[\"incidence_1\"] = torch.sparse_coo_tensor(\n", "    dataset[\"edge_index\"], torch.ones(dataset[\"edge_index\"].shape[1]), dtype=torch.long\n", ")\n", "dataset = dataset.to(device)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "x_0s = dataset[\"x\"]\n", "y = dataset[\"y\"]\n", "incidence_1 = dataset[\"incidence_1\"]" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Train the Neural Network\n", "\n", "We then specify the hyperparameters and construct the model, the loss function and the optimizer."
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "class Network(torch.nn.Module):\n", "    \"\"\"Network class that initializes the base model and readout layer.\n", "\n", "    Base model parameters:\n", "    ----------\n", "    Required:\n", "    in_channels : int\n", "        Dimension of the input features.\n", "    hidden_channels : int\n", "        Dimension of the hidden features.\n", "\n", "    Optional:\n", "    **kwargs : dict\n", "        Additional arguments for the base model.\n", "\n", "    Readout layer parameters:\n", "    ----------\n", "    out_channels : int\n", "        Dimension of the output features.\n", "    task_level : str\n", "        Level of the task. Either \"graph\" or \"node\".\n", "    \"\"\"\n", "\n", "    def __init__(\n", "        self, in_channels, hidden_channels, out_channels, task_level=\"graph\", **kwargs\n", "    ):\n", "        super().__init__()\n", "\n", "        # Define the model\n", "        self.base_model = HMPNN(\n", "            in_channels=in_channels, hidden_channels=hidden_channels, **kwargs\n", "        )\n", "\n", "        # Readout\n", "        self.linear = torch.nn.Linear(hidden_channels, out_channels)\n", "        self.out_pool = task_level == \"graph\"\n", "\n", "    def forward(self, x_0, x_1, incidence_1):\n", "        # Base model\n", "        x_0, x_1 = self.base_model(x_0, x_1, incidence_1)\n", "\n", "        # Pool over all nodes in the hypergraph (graph-level tasks only)\n", "        x = torch.max(x_0, dim=0)[0] if self.out_pool is True else x_0\n", "\n", "        return self.linear(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initialize the model." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Base model hyperparameters\n", "in_channels = x_0s.shape[1]\n", "hidden_channels = 128\n", "n_layers = 1\n", "\n", "# Readout hyperparameters\n", "out_channels = torch.unique(y).shape[0]\n", "task_level = \"graph\" if out_channels == 1 else \"node\"\n", "\n", "\n", "model = Network(\n", "    in_channels=in_channels,\n", "    hidden_channels=hidden_channels,\n", "    out_channels=out_channels,\n", "   
n_layers=n_layers,\n", "    task_level=task_level,\n", ").to(device)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:14:58.153514385Z", "start_time": "2023-06-01T16:14:57.243596119Z" } }, "outputs": [], "source": [ "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", "loss_fn = torch.nn.CrossEntropyLoss()\n", "\n", "\n", "train_mask = dataset[\"train_mask\"]\n", "val_mask = dataset[\"val_mask\"]\n", "test_mask = dataset[\"test_mask\"]" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now it's time to train the model, looping over the network for a small number of epochs. We keep training minimal for the purpose of rapid testing." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Note: The number of epochs below has been kept low to facilitate debugging and testing. Real use cases will likely require more epochs.**" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:15:01.683216142Z", "start_time": "2023-06-01T16:15:00.727075750Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 6 train loss: 1.2665 train acc: 0.82 val loss: 1.9848 val acc: 0.23 test loss: 0.2270 val acc: 0.23\n", "Epoch: 11 train loss: 0.8439 train acc: 0.99 val loss: 1.7569 val acc: 0.38 test loss: 0.3970 val acc: 0.40\n", "Epoch: 16 train loss: 0.5119 train acc: 1.00 val loss: 1.6846 val acc: 0.39 test loss: 0.4160 val acc: 0.42\n", "Epoch: 21 train loss: 0.2717 train acc: 1.00 val loss: 1.5872 val acc: 0.43 test loss: 0.4500 val acc: 0.45\n", "Epoch: 26 train loss: 0.1571 train acc: 1.00 val loss: 1.6143 val acc: 0.41 test loss: 0.4230 val acc: 0.42\n", "Epoch: 31 train loss: 0.0816 train acc: 1.00 val loss: 1.5894 val acc: 0.45 test loss: 0.4490 val acc: 0.45\n", "Epoch: 36 train loss: 0.0478 train acc: 1.00 val loss: 1.6020 val acc: 0.46 test loss: 0.4630 val acc: 0.46\n", "Epoch: 41 train loss: 
0.0298 train acc: 1.00 val loss: 1.6153 val acc: 0.47 test loss: 0.4670 val acc: 0.47\n", "Epoch: 46 train loss: 0.0214 train acc: 1.00 val loss: 1.6499 val acc: 0.47 test loss: 0.4720 val acc: 0.47\n", "Epoch: 51 train loss: 0.0160 train acc: 1.00 val loss: 1.6764 val acc: 0.48 test loss: 0.4830 val acc: 0.48\n", "Epoch: 56 train loss: 0.0149 train acc: 1.00 val loss: 1.6986 val acc: 0.48 test loss: 0.4900 val acc: 0.49\n", "Epoch: 61 train loss: 0.0123 train acc: 1.00 val loss: 1.6888 val acc: 0.47 test loss: 0.4920 val acc: 0.49\n", "Epoch: 66 train loss: 0.0097 train acc: 1.00 val loss: 1.6670 val acc: 0.48 test loss: 0.4970 val acc: 0.50\n", "Epoch: 71 train loss: 0.0078 train acc: 1.00 val loss: 1.6547 val acc: 0.49 test loss: 0.5030 val acc: 0.50\n", "Epoch: 76 train loss: 0.0072 train acc: 1.00 val loss: 1.6484 val acc: 0.49 test loss: 0.5030 val acc: 0.50\n", "Epoch: 81 train loss: 0.0066 train acc: 1.00 val loss: 1.6378 val acc: 0.49 test loss: 0.5100 val acc: 0.51\n", "Epoch: 86 train loss: 0.0064 train acc: 1.00 val loss: 1.6507 val acc: 0.49 test loss: 0.5110 val acc: 0.51\n", "Epoch: 91 train loss: 0.0060 train acc: 1.00 val loss: 1.6745 val acc: 0.50 test loss: 0.5100 val acc: 0.51\n", "Epoch: 96 train loss: 0.0051 train acc: 1.00 val loss: 1.6682 val acc: 0.50 test loss: 0.5150 val acc: 0.52\n", "Epoch: 101 train loss: 0.0047 train acc: 1.00 val loss: 1.6412 val acc: 0.50 test loss: 0.5190 val acc: 0.52\n" ] } ], "source": [ "torch.manual_seed(0)\n", "test_interval = 5\n", "num_epochs = 5\n", "\n", "\n", "initial_x_1 = torch.zeros_like(x_0s)\n", "for epoch in range(1, num_epochs + 1):\n", "    model.train()\n", "    optimizer.zero_grad()\n", "    y_hat = model(x_0s, initial_x_1, incidence_1)\n", "    loss = loss_fn(y_hat[train_mask], y[train_mask])\n", "    loss.backward()\n", "    optimizer.step()\n", "\n", "    train_loss = loss.item()\n", "    y_pred = y_hat.argmax(dim=-1)\n", "    train_acc = accuracy_score(y[train_mask].cpu(), y_pred[train_mask].cpu())\n", "\n", "    if
epoch % test_interval == 0:\n", "        model.eval()\n", "\n", "        y_hat = model(x_0s, initial_x_1, incidence_1)\n", "        y_pred = y_hat.argmax(dim=-1)\n", "\n", "        val_loss = loss_fn(y_hat[val_mask], y[val_mask]).item()\n", "        val_acc = accuracy_score(y[val_mask].cpu(), y_pred[val_mask].cpu())\n", "\n", "        test_loss = loss_fn(y_hat[test_mask], y[test_mask]).item()\n", "        test_acc = accuracy_score(y[test_mask].cpu(), y_pred[test_mask].cpu())\n", "        # Report the current (1-indexed) epoch with the loss and accuracy of each split\n", "        print(\n", "            f\"Epoch: {epoch} train loss: {train_loss:.4f} train acc: {train_acc:.2f} \"\n", "            f\" val loss: {val_loss:.4f} val acc: {val_acc:.2f}\"\n", "            f\" test loss: {test_loss:.4f} test acc: {test_acc:.2f}\"\n", "        )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" } }, "nbformat": 4, "nbformat_minor": 2 }