{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Train an All-Set-Transformer TNN\n", "\n", "In this notebook, we will create and train a two-step message passing network named AllSetTransformer (Chien et al., [2021](https://arxiv.org/abs/2106.13264)) in the hypergraph domain. We will use a benchmark dataset, Cora, a citation network, to train the model to perform classification at the level of the nodes. \n", "\n", "The message passing scheme below follows the notation of the \"awesome-tnns\" [GitHub repo](https://github.com/awesome-tnns/awesome-tnns/blob/main/Hypergraphs.md).\n", "\n", "Vertex to edge: \n", "\n", "🟧 $\\quad m_{\\rightarrow z}^{(\\rightarrow 1)} = AGG_{y \\in \\mathcal{B}(z)} (h_y^{t, (0)}, h_z^{t,(1)}) \\quad \\text{with attention}$ \n", "\n", "🟦 $\\quad h_z^{t+1,(1)} = \\text{LN}(m_{\\rightarrow z}^{(\\rightarrow 1)} + \\text{MLP}(m_{\\rightarrow z}^{(\\rightarrow 1)} ))$ \n", "\n", "Edge to vertex: \n", "\n", "🟧 $\\quad m_{\\rightarrow x}^{(\\rightarrow 0)} = AGG_{z \\in \\mathcal{C}(x)} (h_z^{t+1,(1)}, h_x^{t,(0)}) \\quad \\text{with attention}$ \n", "\n", "🟦 $\\quad h_x^{t+1,(0)} = \\text{LN}(m_{\\rightarrow x}^{(\\rightarrow 0)} + \\text{MLP}(m_{\\rightarrow x}^{(\\rightarrow 0)} ))$\n", "\n", "### Additional theoretical clarifications\n", "\n", "Given a hypergraph $G=(\\mathcal{V}, \\mathcal{E})$, let $\\textbf{X} \\in \\mathbb{R}^{|\\mathcal{V}| \\times F}$ and $\\textbf{Z} \\in \\mathbb{R}^{|\\mathcal{E}| \\times F'}$ denote the hidden node and hyperedge representations, respectively. 
Additionally, define $V_{e, \\textbf{X}} = \\{\\textbf{X}_{u,:}: u \\in e\\}$ as the multiset of hidden node representations in the hyperedge $e$ and $E_{v, \\textbf{Z}} = \\{\\textbf{Z}_{e,:}: v \\in e\\}$ as the multiset of hidden representations of hyperedges containing $v$.\n", "\n", "\\\n", "In this setting, the two general update rules that AllSet's framework puts in place in each layer are:\n", "\n", "🔷 $\\textbf{Z}_{e,:}^{(t+1)} = f_{\\mathcal{V} \\rightarrow \\mathcal{E}}(V_{e, \\textbf{X}^{(t)}}; \\textbf{Z}_{e,:}^{(t)})$\n", "\n", "🔷 $\\textbf{X}_{v,:}^{(t+1)} = f_{\\mathcal{E} \\rightarrow \\mathcal{V}}(E_{v, \\textbf{Z}^{(t+1)}}; \\textbf{X}_{v,:}^{(t)})$\n", "\n", "in which $f_{\\mathcal{V} \\rightarrow \\mathcal{E}}$ and $f_{\\mathcal{E} \\rightarrow \\mathcal{V}}$ are two functions that are permutation invariant with respect to their first input. The initial representations $\\textbf{Z}_{e,:}^{(0)}$ and $\\textbf{X}_{v,:}^{(0)}$ are set to the hyperedge and node features, respectively, if available; otherwise they are set to zero.\n", "\n", "In the practical implementation of the model, $f_{\\mathcal{V} \\rightarrow \\mathcal{E}}$ and $f_{\\mathcal{E} \\rightarrow \\mathcal{V}}$ are parametrized and learnt for each dataset and task, and the information of their second argument is not utilized. 
The option achieving the best results makes use of attention-based layers, giving rise to the so-called AllSetTransformer architecture.\n", "\n", "\\\n", "We now dive deep into the details of AllSetTransformer, describing how the update functions $f_{\\mathcal{V} \\rightarrow \\mathcal{E}}$ and $f_{\\mathcal{E} \\rightarrow \\mathcal{V}}$ are iteratively defined.\n", "Their input is a matrix $\\textbf{S} \\in \\mathbb{R}^{|S| \\times F}$ which corresponds to the multiset of $F$-dimensional feature vectors:\n", "\n", "1️⃣ $\\textbf{K}^{(i)} = \\text{MLP}^{K, i}(\\textbf{S}), \\textbf{V}^{(i)} = \\text{MLP}^{V, i}(\\textbf{S})$, where $i \\in \\{1, ..., h\\},$\n", "\n", "2️⃣ $ \\textbf{O}^{(i)} = \\omega (\\theta^{(i)}(\\textbf{K}^{(i)})^{T}) \\textbf{V}^{(i)},$ \n", "\n", "3️⃣ $\\theta \\overset{\\Delta}{=} \\mathbin\\Vert_{i=1}^{h} \\theta^{(i)}, $\n", "\n", "4️⃣ $ \\text{MH}_{h, \\omega}(\\theta, \\textbf{S}, \\textbf{S}) = \\mathbin\\Vert_{i=1}^{h} \\textbf{O}^{(i)}, $\n", "\n", "5️⃣ $ \\textbf{Y} = \\text{LN} (\\theta + \\text{MH}_{h, \\omega}(\\theta, \\textbf{S}, \\textbf{S})), $\n", "\n", "6️⃣ $f_{\\mathcal{V} \\rightarrow \\mathcal{E}}(\\textbf{S}) = f_{\\mathcal{E} \\rightarrow \\mathcal{V}}(\\textbf{S}) = \\text{LN} (\\textbf{Y} + \\text{MLP}(\\textbf{Y}))$.\n", "\n", "\\\n", "\n", "The elements and operations used in these steps are defined as follows:\n", "\n", "🔶 $\\text{LN}$ means layer normalization (Ba et al., [2016](https://arxiv.org/abs/1607.06450)),\n", "\n", "🔶 $\\mathbin\\Vert$ represents concatenation,\n", "\n", "🔶 $\\theta \\in \\mathbb{R}^{1 \\times hF_{h}}$ is a learnable weight,\n", "\n", "🔶 $\\text{MH}_{h, \\omega}$ denotes a multihead attention mechanism with $h$ heads and activation function $\\omega$ (Vaswani et al., [2017](https://proceedings.neurips.cc/paper_files/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html)),\n", "\n", "🔶 all $\\text{MLP}$ modules are multi-layer perceptrons that operate row-wise, so they are applied 
identically and independently to each multiset element of $\\textbf{S}$.\n", "\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:14:51.222779223Z", "start_time": "2023-06-01T16:14:49.575421023Z" } }, "outputs": [], "source": [ "\"\"\"Imports and random seed for training AllSetTransformer on a hypergraph lifted from Cora.\"\"\"\n", "\n", "import numpy as np\n", "import torch\n", "import torch_geometric.datasets as geom_datasets\n", "from torch_geometric.utils import to_undirected\n", "\n", "from topomodelx.nn.hypergraph.allset_transformer import AllSetTransformer\n", "\n", "torch.manual_seed(0)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "If GPUs are available, we will make use of them. Otherwise, this will run on CPU." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:14:51.959770754Z", "start_time": "2023-06-01T16:14:51.956096841Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cpu\n" ] } ], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Pre-processing\n", "\n", "## Import data ##\n", "\n", "The first step is to import the dataset, Cora, a benchmark classification dataset. 
We then lift the graph into our domain of choice, a hypergraph.\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:14:53.022151550Z", "start_time": "2023-06-01T16:14:52.949636599Z" } }, "outputs": [], "source": [ "cora = geom_datasets.Planetoid(root=\"tmp/\", name=\"cora\")\n", "data = cora.data\n", "\n", "x_0s = data.x\n", "y = data.y\n", "edge_index = data.edge_index\n", "\n", "train_mask = data.train_mask\n", "val_mask = data.val_mask\n", "test_mask = data.test_mask" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Define neighborhood structures and lift into hypergraph domain. ##\n", "\n", "Now we retrieve the neighborhood structure (i.e., its representative matrix) that we will use to send messages from nodes to hyperedges. In the case of this architecture, we need the boundary matrix (or incidence matrix) $B_1$ with shape $n_\\text{nodes} \\times n_\\text{edges}$.\n", "\n", "For the Cora citation dataset, we lift the graph structure to the hypergraph domain by creating a hyperedge from the 1-hop graph neighbourhood of each node. 
\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:14:53.022151550Z", "start_time": "2023-06-01T16:14:52.949636599Z" } }, "outputs": [], "source": [ "# Ensure the graph is undirected (optional but often useful for one-hop neighborhoods).\n", "edge_index = to_undirected(edge_index)\n", "\n", "# Create a list of one-hop neighborhoods for each node.\n", "one_hop_neighborhoods = []\n", "for node in range(data.num_nodes):\n", "    # Get the one-hop neighbors of the current node from the undirected edge list.\n", "    neighbors = edge_index[1, edge_index[0] == node]\n", "\n", "    # Append the neighbors to the list of one-hop neighborhoods.\n", "    one_hop_neighborhoods.append(neighbors.numpy())\n", "\n", "# Detect and eliminate duplicate hyperedges.\n", "unique_hyperedges = set()\n", "hyperedges = []\n", "for neighborhood in one_hop_neighborhoods:\n", "    # Sort the neighborhood to ensure consistent comparison.\n", "    neighborhood = tuple(sorted(neighborhood))\n", "    if neighborhood not in unique_hyperedges:\n", "        hyperedges.append(list(neighborhood))\n", "        unique_hyperedges.add(neighborhood)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Additionally, the statistics of the resulting hyperedges can be printed by uncommenting the following cell." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# # Calculate hyperedge statistics.\n", "# hyperedge_sizes = [len(he) for he in hyperedges]\n", "# min_size = min(hyperedge_sizes)\n", "# max_size = max(hyperedge_sizes)\n", "# mean_size = np.mean(hyperedge_sizes)\n", "# median_size = np.median(hyperedge_sizes)\n", "# std_size = np.std(hyperedge_sizes)\n", "# num_single_node_hyperedges = sum(np.array(hyperedge_sizes) == 1)\n", "\n", "# # Print the hyperedge statistics.\n", "# print(f\"Hyperedge statistics: \")\n", "# print(\"Number of hyperedges without duplicated hyperedges\", len(hyperedges))\n", "# print(f\"min = {min_size}, \")\n", "# print(f\"max = {max_size}, \")\n", 
"# print(f\"mean = {mean_size}, \")\n", "# print(f\"median = {median_size}, \")\n", "# print(f\"std = {std_size}, \")\n", "# print(f\"Number of hyperedges with size equal to one = {num_single_node_hyperedges}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Construct incidence matrix" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "max_edges = len(hyperedges)\n", "incidence_1 = np.zeros((x_0s.shape[0], max_edges))\n", "for col, neighibourhood in enumerate(hyperedges):\n", " for row in neighibourhood:\n", " incidence_1[row, col] = 1\n", "\n", "assert all(incidence_1.sum(0) > 0) is True, \"Some hyperedges are empty\"\n", "assert all(incidence_1.sum(1) > 0) is True, \"Some nodes are not in any hyperedges\"\n", "incidence_1 = torch.Tensor(incidence_1).to_sparse_coo()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Create the Neural Network" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the network that initializes the base model and sets up the readout operation.\n", "Different downstream tasks might require different pooling procedures.\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "class Network(torch.nn.Module):\n", " \"\"\"Network class that initializes the AllSet model and readout layer.\n", "\n", " Base model parameters:\n", " ----------\n", " Reqired:\n", " in_channels : int\n", " Dimension of the input features.\n", " hidden_channels : int\n", " Dimension of the hidden features.\n", "\n", " Optitional:\n", " **kwargs : dict\n", " Additional arguments for the base model.\n", "\n", " Readout layer parameters:\n", " ----------\n", " out_channels : int\n", " Dimension of the output features.\n", " task_level : str\n", " Level of the task. 
Either \"graph\" or \"node\".\n", "    \"\"\"\n", "\n", "    def __init__(\n", "        self, in_channels, hidden_channels, out_channels, task_level=\"graph\", **kwargs\n", "    ):\n", "        super().__init__()\n", "\n", "        # Define the model\n", "        self.base_model = AllSetTransformer(\n", "            in_channels=in_channels, hidden_channels=hidden_channels, **kwargs\n", "        )\n", "\n", "        # Readout\n", "        self.linear = torch.nn.Linear(hidden_channels, out_channels)\n", "        self.out_pool = task_level == \"graph\"\n", "\n", "    def forward(self, x_0, incidence_1):\n", "        # Base model\n", "        x_0, x_1 = self.base_model(x_0, incidence_1)\n", "\n", "        # Pool over all nodes in the hypergraph\n", "        x = torch.max(x_0, dim=0)[0] if self.out_pool is True else x_0\n", "\n", "        return self.linear(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initialize the model" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# Base model hyperparameters\n", "in_channels = x_0s.shape[1]\n", "hidden_channels = 128\n", "\n", "heads = 4\n", "n_layers = 1\n", "mlp_num_layers = 2\n", "\n", "# Readout hyperparameters\n", "out_channels = torch.unique(y).shape[0]\n", "task_level = \"graph\" if out_channels == 1 else \"node\"\n", "\n", "\n", "model = Network(\n", "    in_channels=in_channels,\n", "    hidden_channels=hidden_channels,\n", "    out_channels=out_channels,\n", "    n_layers=n_layers,\n", "    mlp_num_layers=mlp_num_layers,\n", "    task_level=task_level,\n", ").to(device)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Train the Neural Network\n", "\n", "We specify the model, the loss, and an optimizer."
 ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Optimizer and loss\n", "opt = torch.optim.Adam(model.parameters(), lr=0.01)\n", "\n", "# Categorical cross-entropy loss\n", "loss_fn = torch.nn.CrossEntropyLoss()\n", "\n", "\n", "# Accuracy\n", "def acc_fn(y, y_hat):\n", "    return (y == y_hat).float().mean()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:14:59.046068930Z", "start_time": "2023-06-01T16:14:59.037648626Z" } }, "outputs": [], "source": [ "# Cast the existing tensors and move them to the target device.\n", "x_0s, incidence_1, y = (\n", "    x_0s.float().to(device),\n", "    incidence_1.float().to(device),\n", "    y.long().to(device),\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The following cell performs the training, looping over the network for a small number of epochs. We keep training minimal for the purpose of rapid testing."
 ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 10 \n", "Train_loss: 0.9825, acc: 1.0000\n", "Val_loss: 0.9895, Val_acc: 0.7380\n", "Test_loss: 0.8713, Test_acc: 0.7740\n" ] } ], "source": [ "test_interval = 1\n", "num_epochs = 1\n", "\n", "epoch_loss = []\n", "for epoch_i in range(1, num_epochs + 1):\n", "    model.train()\n", "\n", "    opt.zero_grad()\n", "\n", "    # Forward pass over all nodes; the loss uses only the training nodes.\n", "    y_hat = model(x_0s, incidence_1)\n", "    loss = loss_fn(y_hat[train_mask], y[train_mask])\n", "\n", "    loss.backward()\n", "    opt.step()\n", "    epoch_loss.append(loss.item())\n", "\n", "    if epoch_i % test_interval == 0:\n", "        model.eval()\n", "        # Gradients are not needed for evaluation.\n", "        with torch.no_grad():\n", "            y_hat = model(x_0s, incidence_1)\n", "\n", "        loss = loss_fn(y_hat[train_mask], y[train_mask])\n", "        print(f\"Epoch: {epoch_i} \")\n", "        print(\n", "            f\"Train_loss: {np.mean(epoch_loss):.4f}, acc: {acc_fn(y_hat[train_mask].argmax(1), y[train_mask]):.4f}\",\n", "            flush=True,\n", "        )\n", "\n", "        loss = loss_fn(y_hat[val_mask], y[val_mask])\n", "\n", "        print(\n", "            f\"Val_loss: {loss:.4f}, Val_acc: {acc_fn(y_hat[val_mask].argmax(1), y[val_mask]):.4f}\",\n", "            flush=True,\n", "        )\n", "\n", "        loss = loss_fn(y_hat[test_mask], y[test_mask])\n", "        print(\n", "            f\"Test_loss: {loss:.4f}, Test_acc: {acc_fn(y_hat[test_mask].argmax(1), y[test_mask]):.4f}\",\n", "            flush=True,\n", "        )" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" }, "vscode": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } }, "nbformat": 4, "nbformat_minor": 2 }