{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Train a Simplex Convolutional Network (SCN) of Rank 2\n", "\n", "This notebook illustrates the SCN layer proposed in [YSB22]_ for a simplicial complex of\n", "rank 2, that is, for 0-cells (nodes), 1-cells (edges), and 2-cells (faces) only.\n", "\n", "References\n", "----------\n", ".. [YSB22] Ruochen Yang, Frederic Sala, and Paul Bogdan.\n", " Efficient Representation Learning for Higher-Order Data with\n", " Simplicial Complexes. In Bastian Rieck and Razvan Pascanu, editors,\n", " Proceedings of the First Learning on Graphs Conference, volume 198\n", " of Proceedings of Machine Learning Research, pages 13:1–13:21. PMLR,\n", " 09–12 Dec 2022. https://proceedings.mlr.press/v198/yang22a.html." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The autoreload extension is already loaded. To reload it, use:\n", " %reload_ext autoreload\n" ] } ], "source": [ "import numpy as np\n", "import toponetx.datasets as datasets\n", "import torch\n", "from sklearn.model_selection import train_test_split\n", "\n", "from topomodelx.nn.simplicial.scn2 import SCN2\n", "from topomodelx.utils.sparse import from_sparse\n", "\n", "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cpu\n" ] } ], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Pre-processing\n", "\n", "## Import dataset ##\n", "\n", "According to the original paper, SCN performs well at simplex classification. We therefore use shrec_16, a benchmark dataset for 3D mesh classification." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading shrec 16 small dataset...\n", "\n", "done!\n" ] } ], "source": [ "shrec, _ = datasets.mesh.shrec_16(size=\"small\")\n", "\n", "shrec = {key: np.array(value) for key, value in shrec.items()}\n", "x_0s = shrec[\"node_feat\"]\n", "x_1s = shrec[\"edge_feat\"]\n", "x_2s = shrec[\"face_feat\"]\n", "\n", "ys = shrec[\"label\"]\n", "ys = ys.reshape((-1, 1))\n", "simplexes = shrec[\"complexes\"]" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The 6th simplicial complex has 252 nodes with features of dimension 6.\n", "The 6th simplicial complex has 750 edges with features of dimension 10.\n", "The 6th simplicial complex has 500 faces with features of dimension 7.\n" ] } ], "source": [ "i_complex = 6\n", "print(\n", "    f\"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}.\"\n", ")\n", "print(\n", "    f\"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}.\"\n", ")\n", "print(\n", "    f\"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}.\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define neighborhood structures. ##\n", "\n", "Now we retrieve the neighborhood structures (i.e., their representative matrices) that we will use to send messages on the domain. In this case, we need the normalized Laplacian matrices on nodes, edges, and faces. We also convert the neighborhood structures to torch tensors." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "laplacian_0s = []\n", "laplacian_1s = []\n", "laplacian_2s = []\n", "for x in simplexes:\n", " laplacian_0 = x.normalized_laplacian_matrix(rank=0)\n", " laplacian_1 = x.normalized_laplacian_matrix(rank=1)\n", " laplacian_2 = x.normalized_laplacian_matrix(rank=2)\n", "\n", " laplacian_0 = from_sparse(laplacian_0)\n", " laplacian_1 = from_sparse(laplacian_1)\n", " laplacian_2 = from_sparse(laplacian_2)\n", "\n", " laplacian_0s.append(laplacian_0)\n", " laplacian_1s.append(laplacian_1)\n", " laplacian_2s.append(laplacian_2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Train the Neural Network\n", "\n", "We specify the model with our pre-made neighborhood structures and specify an optimizer." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "in_channels_0 = x_0s[i_complex].shape[1]\n", "in_channels_1 = x_1s[i_complex].shape[1]\n", "in_channels_2 = x_2s[i_complex].shape[1]\n", "out_channels = 1" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "class Network(torch.nn.Module):\n", " def __init__(\n", " self, in_channels_0, in_channels_1, in_channels_2, out_channels, n_layers=2\n", " ):\n", " super().__init__()\n", " self.base_model = SCN2(\n", " in_channels_0=in_channels_0,\n", " in_channels_1=in_channels_1,\n", " in_channels_2=in_channels_2,\n", " n_layers=n_layers,\n", " )\n", " self.lin_0 = torch.nn.Linear(in_channels_0, out_channels)\n", " self.lin_1 = torch.nn.Linear(in_channels_1, out_channels)\n", " self.lin_2 = torch.nn.Linear(in_channels_2, out_channels)\n", "\n", " def forward(self, x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2):\n", " x_0, x_1, x_2 = self.base_model(\n", " x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2\n", " )\n", "\n", " x_0 = self.lin_0(x_0)\n", " x_1 = self.lin_1(x_1)\n", " x_2 = self.lin_2(x_2)\n", "\n", " # Take the average 
of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.\n", " two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)\n", " two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0\n", " one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)\n", " one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0\n", " zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)\n", " zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0\n", " # Return the sum of the averages\n", " return (\n", " two_dimensional_cells_mean\n", " + one_dimensional_cells_mean\n", " + zero_dimensional_cells_mean\n", " )" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "n_layers = 2\n", "model = Network(\n", " in_channels_0, in_channels_1, in_channels_2, out_channels, n_layers=n_layers\n", ")\n", "model = model.to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n", "loss_fn = torch.nn.MSELoss()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "test_size = 0.2\n", "x_0s_train, x_0s_test = train_test_split(x_0s, test_size=test_size, shuffle=False)\n", "x_1s_train, x_1s_test = train_test_split(x_1s, test_size=test_size, shuffle=False)\n", "x_2s_train, x_2s_test = train_test_split(x_2s, test_size=test_size, shuffle=False)\n", "\n", "laplacian_0s_train, laplacian_0s_test = train_test_split(\n", " laplacian_0s, test_size=test_size, shuffle=False\n", ")\n", "laplacian_1s_train, laplacian_1s_test = train_test_split(\n", " laplacian_1s, test_size=test_size, shuffle=False\n", ")\n", "laplacian_2s_train, laplacian_2s_test = train_test_split(\n", " laplacian_2s, test_size=test_size, shuffle=False\n", ")\n", "\n", "y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following cell performs the training, looping over the network for a low number of 
epochs." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 1 loss: 106.0943\n", "Epoch: 2 loss: 85.9616\n", "Epoch: 3 loss: 83.0319\n", "Epoch: 4 loss: 80.9816\n", "Epoch: 5 loss: 79.5344\n", "Epoch: 6 loss: 78.9647\n", "Epoch: 7 loss: 78.0454\n", "Epoch: 8 loss: 77.2352\n", "Epoch: 9 loss: 77.0969\n", "Epoch: 10 loss: 76.4979\n", "Test_loss: 27.4806\n", "Epoch: 11 loss: 76.7194\n", "Epoch: 12 loss: 76.1580\n", "Epoch: 13 loss: 78.0519\n", "Epoch: 14 loss: 75.5705\n", "Epoch: 15 loss: 76.1465\n", "Epoch: 16 loss: 76.5432\n", "Epoch: 17 loss: 78.0466\n", "Epoch: 18 loss: 77.0446\n", "Epoch: 19 loss: 80.5840\n", "Epoch: 20 loss: 75.3278\n", "Test_loss: 35.2363\n", "Epoch: 21 loss: 75.6831\n", "Epoch: 22 loss: 74.5437\n", "Epoch: 23 loss: 75.6770\n", "Epoch: 24 loss: 78.6057\n", "Epoch: 25 loss: 75.3240\n", "Epoch: 26 loss: 74.4333\n", "Epoch: 27 loss: 75.4743\n", "Epoch: 28 loss: 76.0289\n", "Epoch: 29 loss: 75.6414\n", "Epoch: 30 loss: 76.2671\n", "Test_loss: 20.8572\n", "Epoch: 31 loss: 75.7918\n", "Epoch: 32 loss: 74.1627\n", "Epoch: 33 loss: 73.6667\n", "Epoch: 34 loss: 72.9490\n", "Epoch: 35 loss: 72.8358\n", "Epoch: 36 loss: 73.0702\n", "Epoch: 37 loss: 73.5069\n", "Epoch: 38 loss: 73.7774\n", "Epoch: 39 loss: 73.2788\n", "Epoch: 40 loss: 73.7978\n", "Test_loss: 21.3558\n", "Epoch: 41 loss: 74.9096\n", "Epoch: 42 loss: 73.2390\n", "Epoch: 43 loss: 72.4291\n", "Epoch: 44 loss: 73.3779\n", "Epoch: 45 loss: 72.3256\n", "Epoch: 46 loss: 72.9241\n", "Epoch: 47 loss: 72.3715\n", "Epoch: 48 loss: 72.1551\n", "Epoch: 49 loss: 72.7596\n", "Epoch: 50 loss: 72.3155\n", "Test_loss: 13.5706\n", "Epoch: 51 loss: 73.5445\n", "Epoch: 52 loss: 72.3427\n", "Epoch: 53 loss: 74.1711\n", "Epoch: 54 loss: 72.1126\n", "Epoch: 55 loss: 71.3567\n", "Epoch: 56 loss: 69.5716\n", "Epoch: 57 loss: 70.7865\n", "Epoch: 58 loss: 70.4044\n", "Epoch: 59 loss: 69.9258\n", "Epoch: 60 loss: 69.7257\n", 
"Test_loss: 11.6862\n", "Epoch: 61 loss: 68.8615\n", "Epoch: 62 loss: 69.6709\n", "Epoch: 63 loss: 69.1890\n", "Epoch: 64 loss: 70.6630\n", "Epoch: 65 loss: 68.8225\n", "Epoch: 66 loss: 68.8715\n", "Epoch: 67 loss: 68.1793\n", "Epoch: 68 loss: 68.7412\n", "Epoch: 69 loss: 71.5032\n", "Epoch: 70 loss: 70.2721\n", "Test_loss: 6.6078\n", "Epoch: 71 loss: 68.2701\n", "Epoch: 72 loss: 69.6752\n", "Epoch: 73 loss: 64.3450\n", "Epoch: 74 loss: 62.4395\n", "Epoch: 75 loss: 62.2776\n", "Epoch: 76 loss: 67.0761\n", "Epoch: 77 loss: 63.7860\n", "Epoch: 78 loss: 60.9918\n", "Epoch: 79 loss: 60.4742\n", "Epoch: 80 loss: 60.2551\n", "Test_loss: 0.2153\n", "Epoch: 81 loss: 60.0425\n", "Epoch: 82 loss: 59.4667\n", "Epoch: 83 loss: 58.2843\n", "Epoch: 84 loss: 57.9831\n", "Epoch: 85 loss: 57.4091\n", "Epoch: 86 loss: 56.9356\n", "Epoch: 87 loss: 57.3874\n", "Epoch: 88 loss: 57.6860\n", "Epoch: 89 loss: 56.4505\n", "Epoch: 90 loss: 56.2496\n", "Test_loss: 0.6815\n", "Epoch: 91 loss: 56.8954\n", "Epoch: 92 loss: 55.0596\n", "Epoch: 93 loss: 55.2672\n", "Epoch: 94 loss: 55.1773\n", "Epoch: 95 loss: 55.1011\n", "Epoch: 96 loss: 54.5242\n", "Epoch: 97 loss: 54.0988\n", "Epoch: 98 loss: 54.5479\n", "Epoch: 99 loss: 54.0969\n", "Epoch: 100 loss: 54.7246\n", "Test_loss: 0.0298\n" ] } ], "source": [ "test_interval = 10\n", "num_epochs = 100\n", "for epoch_i in range(1, num_epochs + 1):\n", " epoch_loss = []\n", " model.train()\n", " for x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2, y in zip(\n", " x_0s_train,\n", " x_1s_train,\n", " x_2s_train,\n", " laplacian_0s_train,\n", " laplacian_1s_train,\n", " laplacian_2s_train,\n", " y_train,\n", " strict=False,\n", " ):\n", " x_0, x_1, x_2, y = (\n", " torch.tensor(x_0).float().to(device),\n", " torch.tensor(x_1).float().to(device),\n", " torch.tensor(x_2).float().to(device),\n", " torch.tensor(y).float().to(device),\n", " )\n", " laplacian_0, laplacian_1, laplacian_2 = (\n", " laplacian_0.float().to(device),\n", " 
laplacian_1.float().to(device),\n", " laplacian_2.float().to(device),\n", " )\n", " optimizer.zero_grad()\n", " y_hat = model(x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2)\n", " loss = loss_fn(y_hat, y)\n", " loss.backward()\n", " optimizer.step()\n", " epoch_loss.append(loss.item())\n", " print(\n", " f\"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}\",\n", " flush=True,\n", " )\n", " if epoch_i % test_interval == 0:\n", " with torch.no_grad():\n", " for x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2, y in zip(\n", " x_0s_test,\n", " x_1s_test,\n", " x_2s_test,\n", " laplacian_0s_test,\n", " laplacian_1s_test,\n", " laplacian_2s_test,\n", " y_test,\n", " strict=False,\n", " ):\n", " x_0, x_1, x_2, y = (\n", " torch.tensor(x_0).float().to(device),\n", " torch.tensor(x_1).float().to(device),\n", " torch.tensor(x_2).float().to(device),\n", " torch.tensor(y).float().to(device),\n", " )\n", " laplacian_0, laplacian_1, laplacian_2 = (\n", " laplacian_0.float().to(device),\n", " laplacian_1.float().to(device),\n", " laplacian_2.float().to(device),\n", " )\n", " y_hat = model(x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2)\n", " test_loss = loss_fn(y_hat, y)\n", " print(f\"Test_loss: {test_loss:.4f}\", flush=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" } }, "nbformat": 4, "nbformat_minor": 4 }