{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Train a DHGCN TNN\n", "\n", "In this notebook, we will create and train a two-step message passing network in the hypergraph domain. We will use a benchmark dataset, shrec16, a collection of 3D meshes, to train the model to perform classification at the level of the hypergraph. " ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:14:51.222779223Z", "start_time": "2023-06-01T16:14:49.575421023Z" } }, "outputs": [], "source": [ "import numpy as np\n", "import toponetx.datasets as datasets\n", "import torch\n", "from sklearn.model_selection import train_test_split\n", "\n", "from topomodelx.nn.hypergraph.dhgcn import DHGCN\n", "from topomodelx.utils.sparse import from_sparse" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "If a GPU is available, we will use it. Otherwise, the notebook runs on the CPU." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:14:51.959770754Z", "start_time": "2023-06-01T16:14:51.956096841Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Pre-processing\n", "\n", "## Import data ##\n", "\n", "The first step is to import the dataset, shrec16, a benchmark dataset for 3D mesh classification. 
Each mesh is stored as a simplicial complex, which we will later lift into our domain of choice, a hypergraph.\n", "\n", "We will also retrieve:\n", "- the input signals on the nodes, edges, and faces of each complex; only the node features are fed to the model as input\n", "- the label associated with each mesh" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:14:53.022151550Z", "start_time": "2023-06-01T16:14:52.949636599Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading shrec 16 small dataset...\n", "\n", "done!\n" ] } ], "source": [ "shrec, _ = datasets.mesh.shrec_16(size=\"small\")\n", "\n", "shrec = {key: np.array(value) for key, value in shrec.items()}\n", "x_0s = shrec[\"node_feat\"]\n", "x_1s = shrec[\"edge_feat\"]\n", "x_2s = shrec[\"face_feat\"]\n", "\n", "ys = shrec[\"label\"]\n", "simplexes = shrec[\"complexes\"]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((100,), (100, 750, 10), (100, 500, 7), (100,), (100,))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_0s.shape, x_1s.shape, x_2s.shape, ys.shape, simplexes.shape" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((252, 6), (750, 10), (500, 7))" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_0s[4].shape, x_1s[0].shape, x_2s[0].shape" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The 6th simplicial complex has 252 nodes with features of dimension 6.\n", "The 6th simplicial complex has 750 edges with features of dimension 10.\n", "The 6th simplicial complex has 500 faces with features of dimension 7.\n" ] } ], "source": [ "i_complex = 6\n", "print(\n", " f\"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} 
nodes with features of dimension {x_0s[i_complex].shape[1]}.\"\n", ")\n", "print(\n", " f\"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}.\"\n", ")\n", "print(\n", " f\"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}.\"\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Define neighborhood structures and lift into the hypergraph domain ##\n", "\n", "Next, we retrieve the neighborhood structures (i.e., their representative matrices) of each simplicial complex. Here, we compute the boundary matrix (or incidence matrix) $B_1$ with shape $n_\\text{nodes} \\times n_\\text{edges}$. Note that the `DHGCN` model defined below receives only the node features as input, so these matrices are not passed to it in this notebook.\n", "\n", "Once we have recorded the incidence matrix (note that all incidence matrices in the hypergraph domain must be unsigned), we lift each simplicial complex into a hypergraph. The edges become pairwise hyperedges, and the faces of the simplicial complex become 3-wise hyperedges." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:14:53.022151550Z", "start_time": "2023-06-01T16:14:52.949636599Z" } }, "outputs": [], "source": [ "hg_list = []\n", "incidence_1_list = []\n", "for simplex in simplexes:\n", " incidence_1 = simplex.incidence_matrix(rank=1, signed=False)\n", " incidence_1 = from_sparse(incidence_1)\n", " incidence_1_list.append(incidence_1)\n", " hg = simplex.to_hypergraph()\n", " hg_list.append(hg)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The 6th hypergraph has an incidence matrix of shape torch.Size([252, 750]).\n" ] } ], "source": [ "i_complex = 6\n", "print(\n", " f\"The {i_complex}th hypergraph has an incidence matrix of shape {incidence_1_list[i_complex].shape}.\"\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Create the Neural Network\n", "\n", "Define the network that initializes the base model and sets up the readout operation.\n", "Different downstream tasks might require different pooling procedures.\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "class Network(torch.nn.Module):\n", " \"\"\"Network class that initializes the base model and readout layer.\n", "\n", " Base model parameters:\n", " ----------\n", " Required:\n", " in_channels : int\n", " Dimension of the input features.\n", " hidden_channels : int\n", " Dimension of the hidden features.\n", "\n", " Optional:\n", " **kwargs : dict\n", " Additional arguments for the base model.\n", "\n", " Readout layer parameters:\n", " ----------\n", " out_channels : int\n", " Dimension of the output features.\n", " task_level : str\n", " Level of the task. 
Either \"graph\" or \"node\".\n", " \"\"\"\n", "\n", " def __init__(\n", " self, in_channels, hidden_channels, out_channels, task_level=\"graph\", **kwargs\n", " ):\n", " super().__init__()\n", "\n", " # Define the model\n", " self.base_model = DHGCN(\n", " in_channels=in_channels, hidden_channels=hidden_channels, **kwargs\n", " )\n", "\n", " # Readout\n", " self.linear = torch.nn.Linear(hidden_channels, out_channels)\n", " self.out_pool = task_level == \"graph\"\n", "\n", " def forward(self, x_0):\n", " # Base model; its second output (x_1) is not used by this readout\n", " x_0, _ = self.base_model(x_0)\n", "\n", " # For graph-level tasks, max-pool over all nodes in the hypergraph\n", " x = torch.max(x_0, dim=0)[0] if self.out_pool else x_0\n", "\n", " return self.linear(x)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Base model hyperparameters\n", "in_channels = x_0s[0].shape[1]\n", "hidden_channels = 6\n", "n_layers = 2\n", "\n", "# Readout hyperparameters\n", "out_channels = 1\n", "task_level = \"graph\" if out_channels == 1 else \"node\"\n", "\n", "\n", "model = Network(\n", " in_channels=in_channels,\n", " hidden_channels=hidden_channels,\n", " out_channels=out_channels,\n", " n_layers=n_layers,\n", " task_level=task_level,\n", ").to(device)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Train the Neural Network\n", "\n", "We specify the model, the loss, and an optimizer." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:14:58.153514385Z", "start_time": "2023-06-01T16:14:57.243596119Z" } }, "outputs": [], "source": [ "loss_fn = torch.nn.MSELoss()\n", "opt = torch.optim.Adam(model.parameters(), lr=0.1)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Split the dataset into train and test sets." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:14:59.046068930Z", "start_time": "2023-06-01T16:14:59.037648626Z" } }, "outputs": [], "source": [ "test_size = 0.2\n", "x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)\n", "y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The following cell trains the model for a small number of epochs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Note: The number of epochs below has been kept low to facilitate debugging and testing. Real use cases will likely require more epochs.**" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2023-06-01T16:15:01.683216142Z", "start_time": "2023-06-01T16:15:00.727075750Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([1])). This will likely lead to incorrect results due to broadcasting. 
Please ensure they have the same size.\n", " return F.mse_loss(input, target, reduction=self.reduction)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 1 loss: 11062303.5101\n", "Test_loss: 135039.8750\n", "Epoch: 2 loss: 12746.0434\n", "Test_loss: 7911.8130\n", "Epoch: 3 loss: 1477.9950\n", "Test_loss: 228.0758\n", "Epoch: 4 loss: 702.3154\n", "Test_loss: 10.2958\n", "Epoch: 5 loss: 872.6976\n", "Test_loss: 1025.2914\n" ] } ], "source": [ "test_interval = 1\n", "num_epochs = 1\n", "for epoch_i in range(1, num_epochs + 1):\n", " epoch_loss = []\n", " model.train()\n", " for x_0, y in zip(x_0_train, y_train, strict=True):\n", " x_0 = torch.tensor(x_0)\n", " x_0, y = (\n", " x_0.float().to(device),\n", " # Wrap the label so the target has shape (1,), matching y_hat\n", " torch.tensor([y], dtype=torch.float).to(device),\n", " )\n", " opt.zero_grad()\n", " y_hat = model(x_0)\n", " loss = loss_fn(y_hat, y)\n", "\n", " loss.backward()\n", " opt.step()\n", " epoch_loss.append(loss.item())\n", "\n", " print(\n", " f\"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}\",\n", " flush=True,\n", " )\n", " if epoch_i % test_interval == 0:\n", " model.eval()\n", " test_loss = []\n", " with torch.no_grad():\n", " for x_0, y in zip(x_0_test, y_test, strict=True):\n", " x_0 = torch.tensor(x_0)\n", " x_0, y = (\n", " x_0.float().to(device),\n", " torch.tensor([y], dtype=torch.float).to(device),\n", " )\n", " y_hat = model(x_0)\n", " test_loss.append(loss_fn(y_hat, y).item())\n", "\n", " # Report the mean loss over the whole test set\n", " print(f\"Test_loss: {np.mean(test_loss):.4f}\", flush=True)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.6 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" }, "vscode": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } }, "nbformat": 4, "nbformat_minor": 2 }