{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "gGVIRSAKNH1n" }, "source": [ "# Train a CW Network (CWN)\n", "\n", "We create and train a specific version of the CWN originally proposed in [Bodnar et al.: Weisfeiler and Lehman Go Cellular: CW Networks (2021)](https://arxiv.org/pdf/2106.12575.pdf).\n", "\n", "### The Neural Network:\n", "\n", "The equations for a single layer of this neural network are given by:\n", "\n", "🟥 $\\quad m_{y \\rightarrow \\{z\\} \\rightarrow x}^{(r \\rightarrow r' \\rightarrow r)} = M_{\\mathcal{L}\\uparrow}(h_x^{t,(r)}, h_y^{t,(r)}, h_z^{t,(r')})$\n", "\n", "🟥 $\\quad m_{y \\rightarrow x}^{(r'' \\rightarrow r)} = M_{\\mathcal{B}}(h_x^{t,(r)}, h_y^{t,(r'')})$\n", "\n", "🟧 $\\quad m_x^{(r'' \\rightarrow r)} = AGG_{y \\in \\mathcal{B}(x)} m_{y \\rightarrow x}^{(r'' \\rightarrow r)}$\n", "\n", "🟧 $\\quad m_x^{(r \\rightarrow r' \\rightarrow r)} = AGG_{y \\in \\mathcal{L}(x)} m_{y \\rightarrow \\{z\\} \\rightarrow x}^{(r \\rightarrow r' \\rightarrow r)}$\n", "\n", "🟩 $\\quad m_x^{(r)} = AGG_{\\mathcal{N}\\_k \\in \\mathcal{N}} (m_x^k)$\n", "\n", "🟦 $\\quad h_x^{t+1,(r)} = U\\left(h_x^{t,(r)}, m_x^{(r)}\\right)$\n", "\n", "The notation is defined in [Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).\n", "\n", "### The Task:\n", "\n", "We train this model to classify entire complexes on a small version of [shrec16](http://shapenet.cs.stanford.edu/shrec16/)." 
] }, { "cell_type": "markdown", "metadata": { "id": "_-rhjPpLNH1t" }, "source": [ "# Set-up\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:06:36.009880829Z", "start_time": "2023-05-31T09:06:34.285257706Z" }, "id": "h-kcMSPGNH1v" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "import toponetx.datasets as datasets\n", "import torch\n", "from sklearn.model_selection import train_test_split\n", "\n", "from topomodelx.nn.cell.cwn import CWN\n", "from topomodelx.utils.sparse import from_sparse\n", "\n", "torch.manual_seed(0)" ] }, { "cell_type": "markdown", "metadata": { "id": "O5V6Ly0qNH1y" }, "source": [ "If GPUs are available, we will make use of them. Otherwise, this will run on CPU." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:53.006542411Z", "start_time": "2023-05-31T09:13:52.963074076Z" }, "colab": { "base_uri": "https://localhost:8080/" }, "id": "v9Vct-axNH1y", "outputId": "99dddc1a-e0e8-4644-d2b8-21b1fc62c2cd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cpu\n" ] } ], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ] }, { "cell_type": "markdown", "metadata": { "id": "zDOonXwjNH10" }, "source": [ "# Pre-processing\n", "\n", "## Import data ##\n", "\n", "The first step is to import the dataset, shrec16, a benchmark dataset for 3D mesh classification. We then lift each graph into our domain of choice, a cell complex.\n", "\n", "We also retrieve:\n", "- input signals `x_0`, `x_1`, `x_2` on the nodes (0-cells), edges (1-cells), and faces (2-cells) for each complex: these will be the model's inputs,\n", "- a scalar classification label `y` associated with the cell complex." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "VEkEnk3XNH11", "outputId": "270b0d6f-0caa-4951-d829-fa496b9842e1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading shrec 16 small dataset...\n", "\n", "done!\n" ] } ], "source": [ "shrec, _ = datasets.mesh.shrec_16(size=\"small\")\n", "\n", "shrec = {key: np.array(value) for key, value in shrec.items()}\n", "x_0s = shrec[\"node_feat\"]\n", "x_1s = shrec[\"edge_feat\"]\n", "x_2s = shrec[\"face_feat\"]\n", "\n", "ys = shrec[\"label\"]\n", "simplexes = shrec[\"complexes\"]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ximOSpO0NH12", "outputId": "0de63af4-75b9-4ec2-d863-14b8c46f9efb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The 6th simplicial complex has 252 nodes with features of dimension 6.\n", "The 6th simplicial complex has 750 edges with features of dimension 10.\n", "The 6th simplicial complex has 500 faces with features of dimension 7.\n" ] } ], "source": [ "i_complex = 6\n", "print(\n", "    f\"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes \"\n", "    f\"with features of dimension {x_0s[i_complex].shape[1]}.\"\n", ")\n", "print(\n", "    f\"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges \"\n", "    f\"with features of dimension {x_1s[i_complex].shape[1]}.\"\n", ")\n", "print(\n", "    f\"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces \"\n", "    f\"with features of dimension {x_2s[i_complex].shape[1]}.\"\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "TazhJt0vNH12" }, "source": [ "## Lift into cell complex domain and define neighborhood structures ##\n", "\n", "We lift each simplicial complex into a cell complex.\n", "\n", "Then, we retrieve the neighborhood structures (i.e. 
their representative matrices) that we will use to send messages on each cell complex. In the case of this architecture, we need the upper adjacency matrix $A_{\\uparrow, r}$, the coboundary matrix $B_r^{\\intercal}$, and the boundary matrix $B_{r+1}$." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "-JwXTU8FNH13" }, "outputs": [], "source": [ "cc_list = []\n", "incidence_2_list = []\n", "adjacency_1_list = []\n", "incidence_1_t_list = []\n", "\n", "for simplex in simplexes:\n", "    cell_complex = simplex.to_cell_complex()\n", "    cc_list.append(cell_complex)\n", "\n", "    incidence_2 = cell_complex.incidence_matrix(rank=2)\n", "    adjacency_1 = cell_complex.adjacency_matrix(rank=1)\n", "    incidence_1_t = cell_complex.incidence_matrix(rank=1).T\n", "\n", "    incidence_2 = from_sparse(incidence_2)\n", "    adjacency_1 = from_sparse(adjacency_1)\n", "    incidence_1_t = from_sparse(incidence_1_t)\n", "\n", "    incidence_2_list.append(incidence_2)\n", "    adjacency_1_list.append(adjacency_1)\n", "    incidence_1_t_list.append(incidence_1_t)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:55.832585216Z", "start_time": "2023-05-31T09:13:55.815448708Z" }, "colab": { "base_uri": "https://localhost:8080/" }, "id": "4s6EISEdNH14", "outputId": "736a4960-da74-42ab-d0dd-cc9f57dc9c12" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The 6th cell complex has an adjacency_1 matrix of shape torch.Size([750, 750]).\n", "The 6th cell complex has an incidence_2 matrix of shape torch.Size([750, 500]).\n", "The 6th cell complex has an incidence_1_t matrix of shape torch.Size([750, 252]).\n" ] } ], "source": [ "i_complex = 6\n", "\n", "print(\n", "    f\"The {i_complex}th cell complex has an adjacency_1 matrix \"\n", "    f\"of shape {adjacency_1_list[i_complex].shape}.\"\n", ")\n", "print(\n", "    f\"The {i_complex}th cell complex has an incidence_2 matrix \"\n", "    f\"of shape 
{incidence_2_list[i_complex].shape}.\"\n", ")\n", "print(\n", "    f\"The {i_complex}th cell complex has an incidence_1_t matrix \"\n", "    f\"of shape {incidence_1_t_list[i_complex].shape}.\"\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "kV2St1ehNH14" }, "source": [ "# Create the Neural Network\n", "\n", "Using the `CWN` class, we create a neural network with stacked layers and a linear readout for each cell rank." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class Network(torch.nn.Module):\n", "    def __init__(\n", "        self,\n", "        in_channels_0,\n", "        in_channels_1,\n", "        in_channels_2,\n", "        hid_channels=16,\n", "        num_classes=1,\n", "        n_layers=2,\n", "    ):\n", "        super().__init__()\n", "        self.base_model = CWN(\n", "            in_channels_0,\n", "            in_channels_1,\n", "            in_channels_2,\n", "            hid_channels=hid_channels,\n", "            n_layers=n_layers,\n", "        )\n", "        self.lin_0 = torch.nn.Linear(hid_channels, num_classes)\n", "        self.lin_1 = torch.nn.Linear(hid_channels, num_classes)\n", "        self.lin_2 = torch.nn.Linear(hid_channels, num_classes)\n", "\n", "    def forward(\n", "        self,\n", "        x_0,\n", "        x_1,\n", "        x_2,\n", "        adjacency_1,\n", "        incidence_2,\n", "        incidence_1_t,\n", "    ):\n", "        x_0, x_1, x_2 = self.base_model(\n", "            x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t\n", "        )\n", "        x_0 = self.lin_0(x_0)\n", "        x_1 = self.lin_1(x_1)\n", "        x_2 = self.lin_2(x_2)\n", "\n", "        # Take the average of the 2D, 1D, and 0D cell features. 
If they are NaN, convert them to 0.\n", "        two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)\n", "        two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0\n", "\n", "        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)\n", "        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0\n", "\n", "        zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)\n", "        zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0\n", "\n", "        # Return the sum of the averages\n", "        return (\n", "            two_dimensional_cells_mean\n", "            + one_dimensional_cells_mean\n", "            + zero_dimensional_cells_mean\n", "        )" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:56.672913561Z", "start_time": "2023-05-31T09:13:56.667986426Z" }, "colab": { "base_uri": "https://localhost:8080/" }, "id": "lJi03cUINH15", "outputId": "da159ed1-7d4b-4635-b690-12e5c3aac16d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The dimensions of input features on nodes, edges and faces are 6, 10 and 7, respectively.\n" ] } ], "source": [ "in_channels_0 = x_0s[0].shape[-1]\n", "in_channels_1 = x_1s[0].shape[-1]\n", "in_channels_2 = x_2s[0].shape[-1]\n", "\n", "print(\n", "    f\"The dimensions of input features on nodes, edges and faces are \"\n", "    f\"{in_channels_0}, {in_channels_1} and {in_channels_2}, respectively.\"\n", ")\n", "model = Network(\n", "    in_channels_0,\n", "    in_channels_1,\n", "    in_channels_2,\n", "    hid_channels=16,\n", "    num_classes=1,\n", "    n_layers=2,\n", ")\n", "model = model.to(device)" ] }, { "cell_type": "markdown", "metadata": { "id": "K-LaZeY3NH16" }, "source": [ "# Train the Neural Network\n", "\n", "With the model instantiated, we specify an optimizer and define a loss function." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:40.411845803Z", "start_time": "2023-05-31T09:19:40.408861921Z" }, "id": "YFC6ZicmNH17" }, "outputs": [], "source": [ "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", "criterion = torch.nn.MSELoss()" ] }, { "cell_type": "markdown", "metadata": { "id": "3yENiotDNH17" }, "source": [ "We split the dataset into train and test sets." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:41.150933630Z", "start_time": "2023-05-31T09:19:41.146986990Z" }, "id": "msAKrXJANH17" }, "outputs": [], "source": [ "test_size = 0.2\n", "\n", "x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)\n", "x_1_train, x_1_test = train_test_split(x_1s, test_size=test_size, shuffle=False)\n", "x_2_train, x_2_test = train_test_split(x_2s, test_size=test_size, shuffle=False)\n", "\n", "adjacency_1_train, adjacency_1_test = train_test_split(\n", "    adjacency_1_list, test_size=test_size, shuffle=False\n", ")\n", "incidence_2_train, incidence_2_test = train_test_split(\n", "    incidence_2_list, test_size=test_size, shuffle=False\n", ")\n", "incidence_1_t_train, incidence_1_t_test = train_test_split(\n", "    incidence_1_t_list, test_size=test_size, shuffle=False\n", ")\n", "\n", "y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "nOMyVj-XNH18" }, "source": [ "We train the CWN for a small number of epochs: we keep training minimal for the purpose of rapid testing." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Note: The number of epochs below has been kept low to facilitate debugging and testing. 
Real use cases will likely require more epochs.**" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:42.918836083Z", "start_time": "2023-05-31T09:19:42.114801039Z" }, "colab": { "base_uri": "https://localhost:8080/" }, "id": "UT-pR2atNH18", "outputId": "8e506805-8aa7-482a-8c66-9708a41478ac" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch:2, Train Loss: 83.8053 Test Loss: 73.7517\n", "Epoch:4, Train Loss: 81.9551 Test Loss: 50.2781\n", "Epoch:6, Train Loss: 78.3991 Test Loss: 49.9035\n", "Epoch:8, Train Loss: 75.8110 Test Loss: 45.7197\n", "Epoch:10, Train Loss: 74.3838 Test Loss: 40.5566\n" ] } ], "source": [ "test_interval = 2\n", "num_epochs = 10\n", "\n", "for epoch_i in range(1, num_epochs + 1):\n", "    epoch_loss = []\n", "    model.train()\n", "\n", "    for x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t, y in zip(\n", "        x_0_train,\n", "        x_1_train,\n", "        x_2_train,\n", "        adjacency_1_train,\n", "        incidence_2_train,\n", "        incidence_1_t_train,\n", "        y_train,\n", "        strict=True,\n", "    ):\n", "        x_0, x_1, x_2, y = (\n", "            torch.tensor(x_0).float().to(device),\n", "            torch.tensor(x_1).float().to(device),\n", "            torch.tensor(x_2).float().to(device),\n", "            torch.tensor([y]).float().to(device),\n", "        )\n", "\n", "        adjacency_1 = adjacency_1.float().to(device)\n", "        incidence_2 = incidence_2.float().to(device)\n", "        incidence_1_t = incidence_1_t.float().to(device)\n", "\n", "        optimizer.zero_grad()\n", "        y_hat = model(x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t)\n", "        loss = criterion(y_hat, y)\n", "        loss.backward()\n", "        optimizer.step()\n", "        epoch_loss.append(loss.item())\n", "\n", "    if epoch_i % test_interval == 0:\n", "        with torch.no_grad():\n", "            train_mean_loss = np.mean(epoch_loss)\n", "            for x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t, y in zip(\n", "                x_0_test,\n", "                x_1_test,\n", "                x_2_test,\n", "                adjacency_1_test,\n", "                incidence_2_test,\n", "                incidence_1_t_test,\n", "                y_test,\n", "                strict=True,\n", "            ):\n", "                x_0, x_1, x_2, y = (\n", "                    torch.tensor(x_0).float().to(device),\n", "                    torch.tensor(x_1).float().to(device),\n", "                    torch.tensor(x_2).float().to(device),\n", "                    torch.tensor([y]).float().to(device),\n", "                )\n", "\n", "                adjacency_1 = adjacency_1.float().to(device)\n", "                incidence_2 = incidence_2.float().to(device)\n", "                incidence_1_t = incidence_1_t.float().to(device)\n", "\n", "                y_hat = model(x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t)\n", "                test_loss = criterion(y_hat, y)\n", "            print(\n", "                f\"Epoch:{epoch_i}, Train Loss: {train_mean_loss:.4f} Test Loss: {test_loss:.4f}\",\n", "                flush=True,\n", "            )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "venv_tmx", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" } }, "nbformat": 4, "nbformat_minor": 0 }