{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Train a Hypergraph Neural Network\n", "\n", "In this notebook, we will create and train a two-step message passing network **HyperGAT** ([Ding et al., 2020](https://aclanthology.org/2020.emnlp-main.399.pdf)) in the hypergraph domain. We will use a benchmark dataset, shrec16, a collection of 3D meshes, to train the model to perform classification at the level of the hypergraph. \n", "\n", "Given a hypergraph $G=(\\mathcal{V}, \\mathcal{E})$, where $|\\mathcal{V}|=n, |\\mathcal{E}|=m$, let $X \\in \\mathbb{R}^{n \\times d}$ and $Z \\in \\mathbb{R}^{m \\times d'}$ denote the hidden node and hyperedge representations, respectively. \n", "\n", "🟥 $\\quad m_{y \\rightarrow z}^{(0 \\rightarrow 1) } = (B^T_1\\odot att(h_{y \\in \\mathcal{B}(z)}^{t,(0)}))\\_{zy} \\cdot h^{t,(0)}_y \\cdot \\Theta^{t,(0)}$ \n", "\n", "🟧 $\\quad m_z^{(1)} = \\sigma(\\sum_{y \\in \\mathcal{B}(z)} m_{y \\rightarrow z}^{(0 \\rightarrow 1)})$ \n", "\n", "🟥 $\\quad m_{z \\rightarrow x}^{(1 \\rightarrow 0)} = (B_1 \\odot att(h_{z \\in \\mathcal{C}(x)}^{t,(1)}))\\_{xz} \\cdot m_{z}^{(1)} \\cdot \\Theta^{t,(1)}$ \n", "\n", "🟧 $\\quad m_{x}^{(0)} = \\sum_{z \\in \\mathcal{C}(x)} m_{z \\rightarrow x}^{(1\\rightarrow0)}$ \n", "\n", "🟩 $\\quad m_x = m_{x}^{(0)}$ \n", "\n", "🟦 $\\quad h_x^{t+1, (0)} = \\sigma(m_x)$\n", "\n", "Given a specific node $\\mathcal{v}_{i}$, a HyperGAT layer first learns the representations of all its connected hyperedges $\\mathcal{E}_{i}$. As not all the nodes in a hyperedge $\\mathcal{e}_{j} \\in \\mathcal{E}_{i}$ contribute equally to the hyperedge meaning, we introduce an attention mechanism (i.e., node-level attention) to highlight the nodes that are important to the meaning of the hyperedge, and then aggregate them to compute the hyperedge representation $\\mathcal{f}_{j}^{l}$. 
Formally:\n", "\n", "$$ \\mathcal{f}_{j}^{l} = \\sigma (\\sum_{\\mathcal{v}_{k} \\in \\mathcal{e}_{j}} \\alpha_{jk} \\mathcal{W}_{1} \n", "\\mathcal{h}_{k}^{l-1}) $$\n", "\n", "where $\\sigma$ is a nonlinearity such as ReLU and $\\mathcal{W}_{1}$ is a trainable weight matrix. $\\alpha_{jk}$ denotes the attention coefficient of node $\\mathcal{v}_{k}$ in the hyperedge $\\mathcal{e}_{j}$, which can be computed by:\n", "\n", "$$ \\alpha_{jk} = \\frac{\\operatorname{exp}(a_{1}^{T}u_{k})}{\\sum\\limits_{\\mathcal{v}_{p} \\in \\mathcal{e}_{j}} \\operatorname{exp}(a_{1}^{T}u_{p})} $$\n", "\n", "where $a_{1}^{T}$ is a weight vector (a.k.a. a context vector).\n", "\n", "Edge-level Attention. With all the hyperedge representations $ \\left\\{ \\mathcal{f}_{j}^{l}| \\forall{\\mathcal{e}_{j}} \\in \\mathcal{E}_{i} \\right\\}$, we again apply an edge-level attention mechanism to highlight the informative hyperedges for learning the next-layer representation of node $\\mathcal{v}_{i}$. This process can be formally expressed as:\n", "\n", "$$ \\mathcal{h}_{i}^{l} = \\sigma (\\sum_{\\mathcal{e}_{j} \\in \\mathcal{E}_{i}} \\beta_{ij} \\mathcal{W}_{2} \\mathcal{f}_{j}^{l}) $$\n", "\n", "where $\\mathcal{h}_{i}^{l}$ is the output representation of node $\\mathcal{v}_{i}$ and $\\mathcal{W}_{2}$ is a weight matrix. 
$\\beta_{ij}$ denotes the attention coefficient of hyperedge $\\mathcal{e}_{j}$ on node $\\mathcal{v}_{i}$ , which can be computed by:\n", "\n", "$$ \\beta_{ij} = \\frac{\\operatorname{exp}(a_{2}^{T}v_{j})}{\\sum\\limits_{\\mathcal{e}_{p} \\in \\mathcal{E}_{i}} \\operatorname{exp}(a_{2}^{T}v_{p})} $$\n", "\n", "$$ \\mathcal{v}_{j} = \\operatorname{LeakyRELU} ([ \\mathcal{W}_{2}\\mathcal{f}_{j}^{l} || \\mathcal{W}_{1}\\mathcal{h}_{i}^{l-1} ]) $$\n", "\n", "where $\\mathcal{a}_{2}^{T}$ is another weight (context) vector for measuring the importance of the hyperedges and || is the concatenation operation.\n", " " ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import toponetx.datasets as datasets\n", "import torch\n", "from sklearn.model_selection import train_test_split\n", "\n", "from topomodelx.nn.hypergraph.hypergat import HyperGAT\n", "from topomodelx.utils.sparse import from_sparse" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Pre-processing\n", "\n", "## Import data ##\n", "\n", "The first step is to import the dataset, shrec 16, a benchmark dataset for 3D mesh classification. 
We then lift each graph into our domain of choice, a hypergraph.\n", "\n", "We will also retrieve:\n", "- the input signal on the nodes of each of these hypergraphs, which is what we feed to the model as input\n", "- the label associated with each hypergraph" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading shrec 16 small dataset...\n", "\n", "done!\n" ] } ], "source": [ "shrec, _ = datasets.mesh.shrec_16(size=\"small\")\n", "\n", "shrec = {key: np.array(value) for key, value in shrec.items()}\n", "x_0s = shrec[\"node_feat\"]\n", "x_1s = shrec[\"edge_feat\"]\n", "x_2s = shrec[\"face_feat\"]\n", "\n", "ys = shrec[\"label\"]\n", "simplexes = shrec[\"complexes\"]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The 6th simplicial complex has 252 nodes with features of dimension 6.\n", "The 6th simplicial complex has 750 edges with features of dimension 10.\n", "The 6th simplicial complex has 500 faces with features of dimension 7.\n" ] } ], "source": [ "i_complex = 6\n", "print(\n", " f\"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}.\"\n", ")\n", "print(\n", " f\"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}.\"\n", ")\n", "print(\n", " f\"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}.\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define neighborhood structures and lift into hypergraph domain. ##\n", "\n", "Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messages on each simplicial complex. 
In the case of this architecture, we need the boundary matrix (or incidence matrix) $B_1$ with shape $n_\\text{nodes} \\times n_\\text{edges}$.\n", "\n", "Once we have recorded the incidence matrix (note that all incidence matrices in the hypergraph domain must be unsigned), we lift each simplicial complex into a hypergraph. The pairwise edges will become pairwise hyperedges, and faces in the simplicial complex will become 3-wise hyperedges." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "hg_list = []\n", "incidence_1_list = []\n", "for simplex in simplexes:\n", " incidence_1 = simplex.incidence_matrix(rank=1, signed=False)\n", " hg = simplex.to_hypergraph()\n", " hg_list.append(hg)\n", "\n", "# Extract the incidence matrix of each collected hypergraph\n", "for hg in hg_list:\n", " incidence_1 = hg.incidence_matrix()\n", " incidence_1 = from_sparse(incidence_1)\n", " incidence_1_list.append(incidence_1)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The 6th hypergraph has an incidence matrix of shape torch.Size([252, 1250]).\n" ] } ], "source": [ "i_complex = 6\n", "print(\n", " f\"The {i_complex}th hypergraph has an incidence matrix of shape {incidence_1_list[i_complex].shape}.\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Train the Neural Network\n", "\n", "Define the network that initializes the base model and sets up the readout operation.\n", "Different downstream tasks might require different pooling procedures." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class Network(torch.nn.Module):\n", " \"\"\"Network class that initializes the HyperGAT model and readout layer.\n", "\n", " Base model parameters:\n", " ----------\n", " Required:\n", " in_channels : int\n", " Dimension of the input features.\n", " hidden_channels : int\n", " Dimension of the hidden features.\n", "\n", " Optional:\n", " **kwargs : dict\n", " Additional arguments for the base model.\n", "\n", " Readout layer parameters:\n", " ----------\n", " out_channels : int\n", " Dimension of the output features.\n", " task_level : str\n", " Level of the task. Either \"graph\" or \"node\".\n", " \"\"\"\n", "\n", " def __init__(\n", " self, in_channels, hidden_channels, out_channels, task_level=\"graph\", **kwargs\n", " ):\n", " super().__init__()\n", "\n", " # Define the model\n", " self.base_model = HyperGAT(\n", " in_channels=in_channels, hidden_channels=hidden_channels, **kwargs\n", " )\n", "\n", " # Readout\n", " self.linear = torch.nn.Linear(hidden_channels, out_channels)\n", " self.out_pool = task_level == \"graph\"\n", "\n", " def forward(self, x_0, incidence_1):\n", " # Base model\n", " x_0, x_1 = self.base_model(x_0, incidence_1)\n", "\n", " # Pool over all nodes in the hypergraph\n", " x = torch.max(x_0, dim=0)[0] if self.out_pool is True else x_0\n", "\n", " return self.linear(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initialize the model" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Base model hyperparameters\n", "in_channels = x_0s[0].shape[1]\n", "hidden_channels = 32\n", "out_dim = 1\n", "n_layers = 3\n", "\n", "# Readout hyperparameters\n", "out_channels = 1\n", "task_level = \"graph\"\n", "\n", "\n", "model = Network(\n", " in_channels=in_channels,\n", " hidden_channels=hidden_channels,\n", " out_channels=out_channels,\n", " n_layers=n_layers,\n", " task_level=task_level,\n", 
").to(device)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# in_channels = x_0s[0].shape[1]\n", "# hidden_channels = 32\n", "# out_dim = 1\n", "# n_layers = 3\n", "\n", "# # Define the model\n", "# model = HyperGAT(\n", "# in_channels=in_channels,\n", "# hidden_channels=hidden_channels,\n", "# out_channels=out_dim,\n", "# n_layers=n_layers\n", "# )\n", "# model = model.to(device)\n", "\n", "# Optimizer and loss\n", "opt = torch.optim.Adam(model.parameters(), lr=0.01)\n", "loss_fn = torch.nn.MSELoss()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "test_size = 0.2\n", "x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)\n", "incidence_1_train, incidence_1_test = train_test_split(\n", " incidence_1_list, test_size=test_size, shuffle=False\n", ")\n", "y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following cell performs the training, looping over the network for a small number of epochs. We keep training minimal for the purpose of rapid testing." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Note: The number of epochs below has been kept low to facilitate debugging and testing. Real use cases will likely require more epochs.**" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([1])). This will likely lead to incorrect results due to broadcasting. 
Please ensure they have the same size.\n", " return F.mse_loss(input, target, reduction=self.reduction)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 5 train_loss 7195.0287 test_loss: 19283.1447\n", "Epoch: 10 train_loss 1705.4479 test_loss: 4695.8421\n", "Epoch: 15 train_loss 2624.2060 test_loss: 3844.4079\n", "Epoch: 20 train_loss 6754.1770 test_loss: 4970.3517\n" ] } ], "source": [ "test_interval = 5\n", "num_epochs = 5\n", "for epoch_i in range(1, num_epochs + 1):\n", " epoch_loss = []\n", " model.train()\n", " for x_0, incidence_1, y in zip(x_0_train, incidence_1_train, y_train, strict=True):\n", " x_0 = torch.tensor(x_0)\n", " x_0, incidence_1, y = (\n", " x_0.float().to(device),\n", " incidence_1.float().to(device),\n", " torch.tensor(y, dtype=torch.float).to(device),\n", " )\n", " opt.zero_grad()\n", " # Extract edge_index from sparse incidence matrix\n", " # edge_index, _ = to_edge_index(incidence_1)\n", " y_hat = model(x_0, incidence_1)\n", " loss = loss_fn(y_hat, y)\n", "\n", " loss.backward()\n", " opt.step()\n", " epoch_loss.append(loss.item())\n", "\n", " if epoch_i % test_interval == 0:\n", " with torch.no_grad():\n", " train_loss = np.mean(epoch_loss)\n", "\n", " test_epoch_loss = []\n", " for x_0, incidence_1, y in zip(\n", " x_0_test, incidence_1_test, y_test, strict=True\n", " ):\n", " x_0 = torch.tensor(x_0)\n", " x_0, incidence_1, y = (\n", " x_0.float().to(device),\n", " incidence_1.float().to(device),\n", " torch.tensor(y, dtype=torch.float).to(device),\n", " )\n", " y_hat = model(x_0, incidence_1)\n", " loss = loss_fn(y_hat, y)\n", " test_epoch_loss.append(loss.item())\n", "\n", " print(\n", " f\"Epoch: {epoch_i} train_loss {train_loss:.4f} test_loss: {np.mean(test_epoch_loss):.4f}\",\n", " flush=True,\n", " )" ] } ], "metadata": { "kernelspec": { "display_name": "torchg", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", 
"mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }