{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Train a SCCNN\n", "\n", "In this notebook, we will create and train a convolutional neural network in the simplicial complex domain, as proposed in the paper by [Yang et al.: Convolutional Learning on Simplicial Complexes (2023)](https://arxiv.org/abs/2301.11163). \n", "\n", "### We train the model to perform:\n", " 1. Complex classification using the shrec16 benchmark dataset.\n", " 2. Node classification using the karate dataset.\n", " \n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simplicial Complex Convolutional Neural Networks [SCCNN]\n", "\n", "SCCNN extends the SCNN to the complex domain by accounting for inter-simplicial connections, i.e., contributions from simplices of adjacent orders. \n", "\n", "For example, we consider SCCNN layers in an SC of order two. At layer $t$, given the inputs on nodes, edges and faces, $\\mathbf{h}_{t}^0,\\mathbf{h}_{t}^1$ and $\\mathbf{h}_{t}^2$, the SCCNN layer computes the edge output as\n", "$$\n", " \\mathbf{h}_{t+1}^1 = \\sigma \\bigg[ \\mathbf{F}_{t,\\downarrow} \\mathbf{B}_{1}^\\top \\mathbf{h}_{t}^{0} + \\mathbf{F}_{t} \\mathbf{h}_t^1 + \\mathbf{F}_{t,\\uparrow} \\mathbf{B}_{2} \\mathbf{h}_t^{2} \\bigg] \n", "$$\n", "where $\\mathbf{F}_t$ is the simplicial convolutional filter defined in the edge space, and $\\mathbf{F}_{t,\\downarrow}$ and $\\mathbf{F}_{t,\\uparrow}$ are the convolutional filters based on, respectively, only the lower and upper Laplacians. 
They are given by \n", "$$\n", " \\mathbf{F}_{t} = {\\theta}_t + \\sum_{p_d=1}^{P_d} {\\theta}_{t,p_d} (\\mathbf{L}_{\\downarrow,1})^{p_d} + \\sum_{p_u=1}^{P_u} {\\theta}_{t,p_u} (\\mathbf{L}_{\\uparrow,1})^{p_u} \n", "$$\n", "$$\n", " \\mathbf{F}_{t,\\downarrow} = {\\theta}_t + \\sum_{p_d=1}^{P_d} {\\theta}_{t,p_d} (\\mathbf{L}_{\\downarrow,1})^{p_d} \n", " \\text{ and } \n", " \\mathbf{F}_{t,\\uparrow} = {\\theta}_t + \\sum_{p_u=1}^{P_u} {\\theta}_{t,p_u} (\\mathbf{L}_{\\uparrow,1})^{p_u} \n", "$$\n", "\n", "Likewise, for the node output, we have \n", "$$\n", " \\mathbf{h}_{t+1}^0 = \\sigma \\bigg[ \\mathbf{F}_{t} \\mathbf{h}_t^0 + \\mathbf{F}_{t,\\uparrow} \\mathbf{B}_{1} \\mathbf{h}_t^{1} \\bigg]\n", "$$\n", "where $\\mathbf{F}_t$ and $\\mathbf{F}_{t,\\uparrow}$ are two graph filters essentially. \n", "\n", "For the face output, we have \n", "$$\n", " \\mathbf{h}_{t+1}^2 = \\sigma \\bigg[ \\mathbf{F}_{t} \\mathbf{h}_{t}^{2} + \\mathbf{F}_{t,\\downarrow} \\mathbf{B}_{2}^\\top \\mathbf{h}_t^{1} \\bigg]\n", "$$\n", "where $\\mathbf{F}_t$ and $\\mathbf{F}_{t,\\downarrow}$ are two simplicial filters defined in the triangle (face) space. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 1. Complex Classification" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The autoreload extension is already loaded. 
To reload it, use:\n", "  %reload_ext autoreload\n" ] } ], "source": [ "import numpy as np\n", "import toponetx.datasets as datasets\n", "import torch\n", "from sklearn.model_selection import train_test_split\n", "\n", "from topomodelx.nn.simplicial.sccnn import SCCNN\n", "from topomodelx.utils.sparse import from_sparse\n", "\n", "%load_ext autoreload\n", "%autoreload 2" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Pre-processing\n", "\n", "### Import shrec dataset ##\n", "\n", "We load the shrec16 dataset, which provides each mesh as a simplicial complex together with features on its nodes, edges and faces." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading shrec 16 small dataset...\n", "\n", "done!\n" ] } ], "source": [ "shrec, _ = datasets.mesh.shrec_16(size=\"small\")\n", "shrec = {key: np.array(value) for key, value in shrec.items()}\n", "x_0s = shrec[\"node_feat\"]\n", "x_1s = shrec[\"edge_feat\"]\n", "x_2s = shrec[\"face_feat\"]\n", "\n", "ys = shrec[\"label\"]\n", "simplexes = shrec[\"complexes\"]" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(6, 10, 7)\n" ] } ], "source": [ "in_channels_0 = x_0s[-1].shape[1]\n", "in_channels_1 = x_1s[-1].shape[1]\n", "in_channels_2 = x_2s[-1].shape[1]\n", "\n", "in_channels_all = (in_channels_0, in_channels_1, in_channels_2)\n", "print(in_channels_all)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Define Neighborhood Structures\n", "Get incidence matrices $\\mathbf{B}_1,\\mathbf{B}_2$ and Hodge Laplacians $\\mathbf{L}_0, \\mathbf{L}_1$ and $\\mathbf{L}_2$.\n", "\n", "Note that the original paper considered the weighted versions of these operators. However, the current TopoNetX package does not provide this feature yet, so the unweighted operators are used here."
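, "\n", "\n", "As a reference for the operators computed in the next cell, a sketch of the standard unweighted definitions, assuming an SC of order two (so that no 3-simplices contribute to $\\mathbf{L}_2$):\n", "$$\n", " \\mathbf{L}_0 = \\mathbf{B}_1 \\mathbf{B}_1^\\top, \\quad \\mathbf{L}_{\\downarrow,1} = \\mathbf{B}_1^\\top \\mathbf{B}_1, \\quad \\mathbf{L}_{\\uparrow,1} = \\mathbf{B}_2 \\mathbf{B}_2^\\top, \\quad \\mathbf{L}_2 = \\mathbf{B}_2^\\top \\mathbf{B}_2\n", "$$\n", "so that $\\mathbf{L}_1 = \\mathbf{L}_{\\downarrow,1} + \\mathbf{L}_{\\uparrow,1}$. This is why the lower and upper parts of $\\mathbf{L}_1$ are kept as separate matrices: the edge filters weight them independently."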
] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "max_rank = 2  # the order of the SC is two\n", "incidence_1_list = []\n", "incidence_2_list = []\n", "\n", "laplacian_0_list = []\n", "laplacian_down_1_list = []\n", "laplacian_up_1_list = []\n", "laplacian_2_list = []\n", "\n", "for simplex in simplexes:\n", "    incidence_1 = simplex.incidence_matrix(rank=1)\n", "    incidence_2 = simplex.incidence_matrix(rank=2)\n", "    laplacian_0 = simplex.hodge_laplacian_matrix(rank=0)\n", "    laplacian_down_1 = simplex.down_laplacian_matrix(rank=1)\n", "    laplacian_up_1 = simplex.up_laplacian_matrix(rank=1)\n", "    laplacian_2 = simplex.hodge_laplacian_matrix(rank=2)\n", "\n", "    incidence_1 = from_sparse(incidence_1)\n", "    incidence_2 = from_sparse(incidence_2)\n", "    laplacian_0 = from_sparse(laplacian_0)\n", "    laplacian_down_1 = from_sparse(laplacian_down_1)\n", "    laplacian_up_1 = from_sparse(laplacian_up_1)\n", "    laplacian_2 = from_sparse(laplacian_2)\n", "\n", "    incidence_1_list.append(incidence_1)\n", "    incidence_2_list.append(incidence_2)\n", "    laplacian_0_list.append(laplacian_0)\n", "    laplacian_down_1_list.append(laplacian_down_1)\n", "    laplacian_up_1_list.append(laplacian_up_1)\n", "    laplacian_2_list.append(laplacian_2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Create and Train the Neural Network\n", "\n", "We specify the model with our pre-made neighborhood structures and specify an optimizer."
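, "\n", "\n", "As a sketch of the readout performed by the `Network` class in the following cell (the symbols $N_k$, $\\mathbf{W}_k$ and $\\mathbf{b}_k$ are introduced here for illustration only and do not come from the paper), the SCCNN output on each rank is passed through its own linear layer, averaged over the simplices of that rank, and the three averages are summed:\n", "$$\n", " \\hat{y} = \\sum_{k=0}^{2} \\frac{1}{N_k} \\sum_{i=1}^{N_k} \\big( \\mathbf{W}_k \\mathbf{h}_{i}^{k} + \\mathbf{b}_k \\big)\n", "$$\n", "where $N_k$ is the number of $k$-simplices in the complex and $\\mathbf{h}_{i}^{k}$ is the SCCNN output feature of the $i$-th $k$-simplex. In the code, NaN entries are ignored when averaging, and an average that is still NaN is replaced by zero."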
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "class Network(torch.nn.Module):\n", "    def __init__(\n", "        self,\n", "        in_channels_all,\n", "        hidden_channels_all,\n", "        out_channels,\n", "        conv_order,\n", "        max_rank,\n", "        n_layers=2,\n", "    ):\n", "        super().__init__()\n", "        self.base_model = SCCNN(\n", "            in_channels_all=in_channels_all,\n", "            hidden_channels_all=hidden_channels_all,\n", "            conv_order=conv_order,\n", "            sc_order=max_rank,\n", "            n_layers=n_layers,\n", "        )\n", "        out_channels_0, out_channels_1, out_channels_2 = hidden_channels_all\n", "        self.out_linear_0 = torch.nn.Linear(out_channels_0, out_channels)\n", "        self.out_linear_1 = torch.nn.Linear(out_channels_1, out_channels)\n", "        self.out_linear_2 = torch.nn.Linear(out_channels_2, out_channels)\n", "\n", "    def forward(self, x_all, laplacian_all, incidence_all):\n", "        x_all = self.base_model(x_all, laplacian_all, incidence_all)\n", "        x_0, x_1, x_2 = x_all\n", "\n", "        x_0 = self.out_linear_0(x_0)\n", "        x_1 = self.out_linear_1(x_1)\n", "        x_2 = self.out_linear_2(x_2)\n", "\n", "        # Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.\n", "        two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)\n", "        two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0\n", "        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)\n", "        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0\n", "        zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)\n", "        zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0\n", "        # Return the sum of the averages\n", "        return (\n", "            two_dimensional_cells_mean\n", "            + one_dimensional_cells_mean\n", "            + zero_dimensional_cells_mean\n", "        )" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Network(\n", "  (base_model): SCCNN(\n", "    (in_linear_0): Linear(in_features=6, out_features=16, bias=True)\n", "    (in_linear_1): Linear(in_features=10, out_features=16, bias=True)\n", "    (in_linear_2): Linear(in_features=7, out_features=16, bias=True)\n", "    (layers): ModuleList(\n", "      (0-1): 2 x SCCNNLayer()\n", "    )\n", "  )\n", "  (out_linear_0): Linear(in_features=16, out_features=1, bias=True)\n", "  (out_linear_1): Linear(in_features=16, out_features=1, bias=True)\n", "  (out_linear_2): Linear(in_features=16, out_features=1, bias=True)\n", ")\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/gbg141/Documents/TopoProjectX/TopoModelX/venv_modelx/lib/python3.11/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.\n", "  warnings.warn(warning.format(ret))\n" ] } ], "source": [ "conv_order = 2\n", "intermediate_channels_all = (16, 16, 16)\n", "num_layers = 2\n", "out_channels = 1  # num classes\n", "\n", "model = Network(\n", "    in_channels_all=in_channels_all,\n", "    hidden_channels_all=intermediate_channels_all,\n", "    out_channels=out_channels,\n", "    conv_order=conv_order,\n", "    max_rank=max_rank,\n", " 
n_layers=num_layers,\n", ")\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", "# `size_average` is deprecated; `reduction=\"mean\"` alone gives the same behavior.\n", "loss_fn = torch.nn.MSELoss(reduction=\"mean\")\n", "print(model)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "test_size = 0.2\n", "x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)\n", "x_1_train, x_1_test = train_test_split(x_1s, test_size=test_size, shuffle=False)\n", "x_2_train, x_2_test = train_test_split(x_2s, test_size=test_size, shuffle=False)\n", "\n", "incidence_1_train, incidence_1_test = train_test_split(\n", "    incidence_1_list, test_size=test_size, shuffle=False\n", ")\n", "incidence_2_train, incidence_2_test = train_test_split(\n", "    incidence_2_list, test_size=test_size, shuffle=False\n", ")\n", "laplacian_0_train, laplacian_0_test = train_test_split(\n", "    laplacian_0_list, test_size=test_size, shuffle=False\n", ")\n", "laplacian_down_1_train, laplacian_down_1_test = train_test_split(\n", "    laplacian_down_1_list, test_size=test_size, shuffle=False\n", ")\n", "laplacian_up_1_train, laplacian_up_1_test = train_test_split(\n", "    laplacian_up_1_list, test_size=test_size, shuffle=False\n", ")\n", "laplacian_2_train, laplacian_2_test = train_test_split(\n", "    laplacian_2_list, test_size=test_size, shuffle=False\n", ")\n", "\n", "y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We train the SCCNN for only a few epochs: training is kept minimal for the purpose of rapid testing." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/gbg141/Documents/TopoProjectX/TopoModelX/venv_modelx/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([1])). 
This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n", "  return F.mse_loss(input, target, reduction=self.reduction)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 1 loss: 399008.0930\n", "Test_loss: 926.6080\n", "Epoch: 2 loss: 477.8325\n", "Test_loss: 204.4115\n", "Epoch: 3 loss: 299.4982\n", "Test_loss: 243.1712\n", "Epoch: 4 loss: 202.5915\n", "Test_loss: 302.4839\n", "Epoch: 5 loss: 147.9002\n", "Test_loss: 311.9497\n" ] } ], "source": [ "test_interval = 1\n", "num_epochs = 5\n", "\n", "for epoch_i in range(1, num_epochs + 1):\n", "    epoch_loss = []\n", "    model.train()\n", "    for (\n", "        x_0,\n", "        x_1,\n", "        x_2,\n", "        incidence_1,\n", "        incidence_2,\n", "        laplacian_0,\n", "        laplacian_down_1,\n", "        laplacian_up_1,\n", "        laplacian_2,\n", "        y,\n", "    ) in zip(\n", "        x_0_train,\n", "        x_1_train,\n", "        x_2_train,\n", "        incidence_1_train,\n", "        incidence_2_train,\n", "        laplacian_0_train,\n", "        laplacian_down_1_train,\n", "        laplacian_up_1_train,\n", "        laplacian_2_train,\n", "        y_train,\n", "        strict=False,\n", "    ):\n", "        x_0 = torch.tensor(x_0)\n", "        x_1 = torch.tensor(x_1)\n", "        x_2 = torch.tensor(x_2)\n", "        y = torch.tensor(y, dtype=torch.float)\n", "        optimizer.zero_grad()\n", "        x_all = (x_0.float(), x_1.float(), x_2.float())\n", "        laplacian_all = (laplacian_0, laplacian_down_1, laplacian_up_1, laplacian_2)\n", "        incidence_all = (incidence_1, incidence_2)\n", "\n", "        y_hat = model(x_all, laplacian_all, incidence_all)\n", "\n", "        # print(y_hat)\n", "        loss = loss_fn(y_hat, y)\n", "\n", "        epoch_loss.append(loss.item())\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "    print(\n", "        f\"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}\",\n", "        flush=True,\n", "    )\n", "    if epoch_i % test_interval == 0:\n", "        with torch.no_grad():\n", "            for (\n", "                x_0,\n", "                x_1,\n", "                x_2,\n", "                incidence_1,\n", "                incidence_2,\n", "                laplacian_0,\n", "                laplacian_down_1,\n", "                laplacian_up_1,\n", "                laplacian_2,\n", "                y,\n", "            ) in zip(\n", "                x_0_test,\n", "                x_1_test,\n", "                x_2_test,\n", "                incidence_1_test,\n", "                incidence_2_test,\n", "                laplacian_0_test,\n", "                laplacian_down_1_test,\n", "                laplacian_up_1_test,\n", "                laplacian_2_test,\n", "                y_test,\n", "                strict=False,\n", "            ):\n", "                x_0 = torch.tensor(x_0)\n", "                x_1 = torch.tensor(x_1)\n", "                x_2 = torch.tensor(x_2)\n", "                y = torch.tensor(y, dtype=torch.float)\n", "                x_all = (x_0.float(), x_1.float(), x_2.float())\n", "                laplacian_all = (\n", "                    laplacian_0,\n", "                    laplacian_down_1,\n", "                    laplacian_up_1,\n", "                    laplacian_2,\n", "                )\n", "                incidence_all = (incidence_1, incidence_2)\n", "\n", "                y_hat = model(x_all, laplacian_all, incidence_all)\n", "\n", "                loss = loss_fn(y_hat, y)\n", "            print(f\"Test_loss: {loss:.4f}\", flush=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2. Node Classification " ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The autoreload extension is already loaded. To reload it, use:\n", "  %reload_ext autoreload\n" ] } ], "source": [ "import toponetx.datasets.graph as graph\n", "\n", "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Pre-processing\n", "\n", "## Import dataset ##\n", "\n", "The first step is to import the Karate Club (https://www.jstor.org/stable/3629752) dataset. This is a single graph with 34 nodes, each belonging to one of two social groups. We will use these groups for the task of node-level binary classification.\n", "\n", "We must first lift our graph dataset into the simplicial complex domain."
] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Simplicial Complex with shape (34, 78, 45, 11, 2) and dimension 4\n", "4\n" ] } ], "source": [ "dataset = graph.karate_club(complex_type=\"simplicial\")\n", "print(dataset)\n", "max_rank = dataset.dim\n", "print(max_rank)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define Neighborhood Structures\n", "Get incidence matrices $\\mathbf{B}_1,\\mathbf{B}_2$, the Hodge Laplacian $\\mathbf{L}_0$, and the lower and upper parts of $\\mathbf{L}_1$ and $\\mathbf{L}_2$.\n", "\n", "Note that the original paper considered the weighted versions of these operators. However, the current TopoNetX package does not provide this feature yet, so the unweighted operators are used here." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The incidence matrix B1 has shape: (34, 78).\n", "The incidence matrix B2 has shape: (78, 45).\n" ] } ], "source": [ "incidence_1 = dataset.incidence_matrix(rank=1)\n", "incidence_2 = dataset.incidence_matrix(rank=2)\n", "\n", "print(f\"The incidence matrix B1 has shape: {incidence_1.shape}.\")\n", "print(f\"The incidence matrix B2 has shape: {incidence_2.shape}.\")" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "laplacian_0 = dataset.hodge_laplacian_matrix(rank=0)\n", "laplacian_down_1 = dataset.down_laplacian_matrix(rank=1)\n", "laplacian_up_1 = dataset.up_laplacian_matrix(rank=1)\n", "laplacian_down_2 = dataset.down_laplacian_matrix(rank=2)\n", "laplacian_up_2 = dataset.up_laplacian_matrix(rank=2)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "laplacian_0 = from_sparse(laplacian_0)\n", "laplacian_down_1 = from_sparse(laplacian_down_1)\n", "laplacian_up_1 = from_sparse(laplacian_up_1)\n", "laplacian_down_2 = from_sparse(laplacian_down_2)\n", "laplacian_up_2 = from_sparse(laplacian_up_2)\n", "\n", 
"incidence_1 = from_sparse(incidence_1)\n", "incidence_2 = from_sparse(incidence_2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import signal ##\n", "\n", "We retrieve an input signal on the nodes, edges and faces. Each signal has shape $n_\\text{simplices} \\times$ in_channels, where $n_\\text{simplices}$ is the number of simplices of the given rank and in_channels is the dimension of each simplex's feature vector. Here, in_channels $= 2$ for nodes, edges and faces. " ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "\"\"\"A function to obtain features based on the input: rank\n", "\"\"\"\n", "\n", "\n", "def get_simplicial_features(dataset, rank):\n", "    if rank == 0:\n", "        which_feat = \"node_feat\"\n", "    elif rank == 1:\n", "        which_feat = \"edge_feat\"\n", "    elif rank == 2:\n", "        which_feat = \"face_feat\"\n", "    else:\n", "        raise ValueError(\n", "            \"input dimension must be 0, 1 or 2, because features are supported on nodes, edges and faces\"\n", "        )\n", "\n", "    x = list(dataset.get_simplex_attributes(which_feat).values())\n", "    return torch.tensor(np.stack(x))" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "There are 34 nodes with features of dimension 2.\n", "There are 78 edges with features of dimension 2.\n", "There are 45 faces with features of dimension 2.\n" ] } ], "source": [ "x_0 = get_simplicial_features(dataset, rank=0)\n", "x_1 = get_simplicial_features(dataset, rank=1)\n", "x_2 = get_simplicial_features(dataset, rank=2)\n", "print(f\"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.\")\n", "print(f\"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.\")\n", "print(f\"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define binary labels\n", "We retrieve the labels associated with the nodes of the input complex. 
In the Karate Club dataset, two social groups emerge, so we assign each node a binary label indicating which group it belongs to.\n", "\n", "We convert the binary labels into one-hot encoded form, train on the first 30 nodes, and hold out the last four nodes' true labels for testing." ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "y = np.array(\n", "    [\n", "        1,\n", "        1,\n", "        1,\n", "        1,\n", "        1,\n", "        1,\n", "        1,\n", "        1,\n", "        1,\n", "        0,\n", "        1,\n", "        1,\n", "        1,\n", "        1,\n", "        0,\n", "        0,\n", "        1,\n", "        1,\n", "        0,\n", "        1,\n", "        0,\n", "        1,\n", "        0,\n", "        0,\n", "        0,\n", "        0,\n", "        0,\n", "        0,\n", "        0,\n", "        0,\n", "        0,\n", "        0,\n", "        0,\n", "        0,\n", "    ]\n", ")\n", "y_true = np.zeros((34, 2))\n", "y_true[:, 0] = y\n", "y_true[:, 1] = 1 - y\n", "y_train = y_true[:30]\n", "y_test = y_true[-4:]\n", "\n", "y_train = torch.from_numpy(y_train)\n", "y_test = torch.from_numpy(y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Create and Train the Neural Network\n", "\n", "We specify the model with our pre-made neighborhood structures and specify an optimizer."
] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "class Network(torch.nn.Module):\n", "    def __init__(\n", "        self,\n", "        in_channels_all,\n", "        hidden_channels_all,\n", "        out_channels,\n", "        conv_order,\n", "        max_rank,\n", "        update_func=None,\n", "        n_layers=2,\n", "    ):\n", "        super().__init__()\n", "        self.base_model = SCCNN(\n", "            in_channels_all=in_channels_all,\n", "            hidden_channels_all=hidden_channels_all,\n", "            conv_order=conv_order,\n", "            sc_order=max_rank,\n", "            update_func=update_func,\n", "            n_layers=n_layers,\n", "        )\n", "        out_channels_0, _, _ = hidden_channels_all\n", "        self.out_linear_0 = torch.nn.Linear(out_channels_0, out_channels)\n", "\n", "    def forward(self, x_all, laplacian_all, incidence_all):\n", "        x_all = self.base_model(x_all, laplacian_all, incidence_all)\n", "\n", "        # We pass the output on the nodes to a linear layer and use that\n", "        # to generate a probability label for nodes.\n", "        x_0, _, _ = x_all\n", "        logits = self.out_linear_0(x_0)\n", "\n", "        return torch.sigmoid(logits)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "\"\"\"Obtain the initial features on all simplices\"\"\"\n", "x_all = (x_0, x_1, x_2)\n", "\n", "conv_order = 2\n", "in_channels_all = (x_0.shape[-1], x_1.shape[-1], x_2.shape[-1])\n", "intermediate_channels_all = (16, 16, 16)\n", "num_layers = 2\n", "out_channels = 2  # num classes\n", "\n", "laplacian_all = (\n", "    laplacian_0,\n", "    laplacian_down_1,\n", "    laplacian_up_1,\n", "    laplacian_down_2,\n", "    laplacian_up_2,\n", ")\n", "\n", "incidence_all = (incidence_1, incidence_2)\n", "\n", "model = Network(\n", "    in_channels_all=in_channels_all,\n", "    hidden_channels_all=intermediate_channels_all,\n", "    out_channels=out_channels,\n", "    conv_order=conv_order,\n", "    max_rank=max_rank,\n", "    update_func=\"sigmoid\",\n", "    n_layers=num_layers,\n", ")\n", "\n", "optimizer = 
torch.optim.Adam(model.parameters(), lr=0.1)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 1 loss: 0.6788 Train_acc: 0.4667\n", "Epoch: 2 loss: 0.6248 Train_acc: 0.8000\n", "Epoch: 3 loss: 0.6677 Train_acc: 0.7667\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 4 loss: 0.5903 Train_acc: 0.8000\n", "Epoch: 5 loss: 0.5934 Train_acc: 0.8000\n", "Epoch: 6 loss: 0.5666 Train_acc: 0.8000\n", "Epoch: 7 loss: 0.5496 Train_acc: 0.8000\n", "Epoch: 8 loss: 0.5381 Train_acc: 0.8000\n", "Epoch: 9 loss: 0.5306 Train_acc: 0.8000\n", "Epoch: 10 loss: 0.5257 Train_acc: 0.8000\n", "Test_acc: 0.5000\n", "Epoch: 11 loss: 0.5223 Train_acc: 0.8000\n", "Epoch: 12 loss: 0.5198 Train_acc: 0.8000\n", "Epoch: 13 loss: 0.5182 Train_acc: 0.8000\n", "Epoch: 14 loss: 0.5170 Train_acc: 0.8000\n", "Epoch: 15 loss: 0.5162 Train_acc: 0.8000\n", "Epoch: 16 loss: 0.5156 Train_acc: 0.8000\n", "Epoch: 17 loss: 0.5151 Train_acc: 0.8000\n", "Epoch: 18 loss: 0.5147 Train_acc: 0.8000\n", "Epoch: 19 loss: 0.5144 Train_acc: 0.8000\n", "Epoch: 20 loss: 0.5142 Train_acc: 0.8000\n", "Test_acc: 0.5000\n", "Epoch: 21 loss: 0.5140 Train_acc: 0.8000\n", "Epoch: 22 loss: 0.5139 Train_acc: 0.8000\n", "Epoch: 23 loss: 0.5138 Train_acc: 0.8000\n", "Epoch: 24 loss: 0.5137 Train_acc: 0.8000\n", "Epoch: 25 loss: 0.5136 Train_acc: 0.8000\n", "Epoch: 26 loss: 0.5135 Train_acc: 0.8000\n", "Epoch: 27 loss: 0.5135 Train_acc: 0.8000\n", "Epoch: 28 loss: 0.5134 Train_acc: 0.8000\n", "Epoch: 29 loss: 0.5134 Train_acc: 0.8000\n", "Epoch: 30 loss: 0.5133 Train_acc: 0.8000\n", "Test_acc: 0.5000\n", "Epoch: 31 loss: 0.5133 Train_acc: 0.8000\n", "Epoch: 32 loss: 0.5133 Train_acc: 0.8000\n", "Epoch: 33 loss: 0.5132 Train_acc: 0.8000\n", "Epoch: 34 loss: 0.5132 Train_acc: 0.8000\n", "Epoch: 35 loss: 0.5132 Train_acc: 0.8000\n", "Epoch: 36 loss: 0.5132 Train_acc: 0.8000\n", "Epoch: 37 loss: 0.5131 Train_acc: 0.8000\n", 
"Epoch: 38 loss: 0.5131 Train_acc: 0.8000\n", "Epoch: 39 loss: 0.5131 Train_acc: 0.8000\n", "Epoch: 40 loss: 0.5131 Train_acc: 0.8000\n", "Test_acc: 0.5000\n", "Epoch: 41 loss: 0.5130 Train_acc: 0.8000\n", "Epoch: 42 loss: 0.5130 Train_acc: 0.8000\n", "Epoch: 43 loss: 0.5130 Train_acc: 0.8000\n", "Epoch: 44 loss: 0.5130 Train_acc: 0.8000\n", "Epoch: 45 loss: 0.5129 Train_acc: 0.8000\n", "Epoch: 46 loss: 0.5129 Train_acc: 0.8000\n", "Epoch: 47 loss: 0.5129 Train_acc: 0.8000\n", "Epoch: 48 loss: 0.5128 Train_acc: 0.8000\n", "Epoch: 49 loss: 0.5128 Train_acc: 0.8000\n", "Epoch: 50 loss: 0.5128 Train_acc: 0.8000\n", "Test_acc: 0.5000\n", "Epoch: 51 loss: 0.5127 Train_acc: 0.8000\n", "Epoch: 52 loss: 0.5127 Train_acc: 0.8000\n", "Epoch: 53 loss: 0.5126 Train_acc: 0.8000\n", "Epoch: 54 loss: 0.5126 Train_acc: 0.8000\n", "Epoch: 55 loss: 0.5125 Train_acc: 0.8000\n", "Epoch: 56 loss: 0.5124 Train_acc: 0.8000\n", "Epoch: 57 loss: 0.5124 Train_acc: 0.8000\n", "Epoch: 58 loss: 0.5123 Train_acc: 0.8000\n", "Epoch: 59 loss: 0.5122 Train_acc: 0.8000\n", "Epoch: 60 loss: 0.5121 Train_acc: 0.8000\n", "Test_acc: 0.5000\n", "Epoch: 61 loss: 0.5120 Train_acc: 0.8000\n", "Epoch: 62 loss: 0.5119 Train_acc: 0.8000\n", "Epoch: 63 loss: 0.5118 Train_acc: 0.8000\n", "Epoch: 64 loss: 0.5116 Train_acc: 0.8000\n", "Epoch: 65 loss: 0.5115 Train_acc: 0.8000\n", "Epoch: 66 loss: 0.5113 Train_acc: 0.8000\n", "Epoch: 67 loss: 0.5111 Train_acc: 0.8000\n", "Epoch: 68 loss: 0.5109 Train_acc: 0.8000\n", "Epoch: 69 loss: 0.5107 Train_acc: 0.8000\n", "Epoch: 70 loss: 0.5105 Train_acc: 0.8000\n", "Test_acc: 0.5000\n", "Epoch: 71 loss: 0.5103 Train_acc: 0.8000\n", "Epoch: 72 loss: 0.5101 Train_acc: 0.8000\n", "Epoch: 73 loss: 0.5099 Train_acc: 0.8000\n", "Epoch: 74 loss: 0.5098 Train_acc: 0.8000\n", "Epoch: 75 loss: 0.5096 Train_acc: 0.8000\n", "Epoch: 76 loss: 0.5094 Train_acc: 0.8000\n", "Epoch: 77 loss: 0.5092 Train_acc: 0.8000\n", "Epoch: 78 loss: 0.5089 Train_acc: 0.8000\n", "Epoch: 79 loss: 0.5085 
Train_acc: 0.8000\n", "Epoch: 80 loss: 0.5082 Train_acc: 0.8000\n", "Test_acc: 0.5000\n", "Epoch: 81 loss: 0.5078 Train_acc: 0.8000\n", "Epoch: 82 loss: 0.5075 Train_acc: 0.8000\n", "Epoch: 83 loss: 0.5071 Train_acc: 0.8000\n", "Epoch: 84 loss: 0.5068 Train_acc: 0.8000\n", "Epoch: 85 loss: 0.5064 Train_acc: 0.8000\n", "Epoch: 86 loss: 0.5060 Train_acc: 0.8000\n", "Epoch: 87 loss: 0.5056 Train_acc: 0.8000\n", "Epoch: 88 loss: 0.5051 Train_acc: 0.8000\n", "Epoch: 89 loss: 0.5046 Train_acc: 0.8000\n", "Epoch: 90 loss: 0.5041 Train_acc: 0.8000\n", "Test_acc: 0.5000\n", "Epoch: 91 loss: 0.5037 Train_acc: 0.8000\n", "Epoch: 92 loss: 0.5032 Train_acc: 0.8000\n", "Epoch: 93 loss: 0.5027 Train_acc: 0.8000\n", "Epoch: 94 loss: 0.5022 Train_acc: 0.8000\n", "Epoch: 95 loss: 0.5016 Train_acc: 0.8000\n", "Epoch: 96 loss: 0.5011 Train_acc: 0.8000\n", "Epoch: 97 loss: 0.5005 Train_acc: 0.8000\n", "Epoch: 98 loss: 0.5000 Train_acc: 0.8000\n", "Epoch: 99 loss: 0.4994 Train_acc: 0.8000\n", "Epoch: 100 loss: 0.4989 Train_acc: 0.8000\n", "Test_acc: 0.5000\n" ] } ], "source": [ "test_interval = 10\n", "num_epochs = 100\n", "for epoch_i in range(1, num_epochs + 1):\n", "    epoch_loss = []\n", "    model.train()\n", "    optimizer.zero_grad()\n", "    y_hat = model(x_all, laplacian_all, incidence_all)\n", "    y_hat = torch.softmax(y_hat, dim=1)\n", "    loss = torch.nn.functional.binary_cross_entropy(\n", "        y_hat[: len(y_train)].float(), y_train.float()\n", "    )\n", "    epoch_loss.append(loss.item())\n", "    loss.backward()\n", "    optimizer.step()\n", "\n", "    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))\n", "    accuracy = (y_pred[: len(y_train)] == y_train).all(dim=1).float().mean().item()\n", "    print(\n", "        f\"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}\",\n", "        flush=True,\n", "    )\n", "    if epoch_i % test_interval == 0:\n", "        with torch.no_grad():\n", "            y_hat_test = model(x_all, laplacian_all, incidence_all)\n", "            # Projection to node-level\n", "            y_hat_test = torch.softmax(y_hat_test, dim=1)\n", "            y_pred_test = torch.where(\n", "                y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)\n", "            )\n", "            test_accuracy = (\n", "                torch.eq(y_pred_test[-len(y_test) :], y_test)\n", "                .all(dim=1)\n", "                .float()\n", "                .mean()\n", "                .item()\n", "            )\n", "            print(f\"Test_acc: {test_accuracy:.4f}\", flush=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" } }, "nbformat": 4, "nbformat_minor": 4 }