{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from sklearn.model_selection import train_test_split\n", "from toponetx.classes.simplicial_complex import SimplicialComplex\n", "from torch_geometric.datasets import TUDataset\n", "from torch_geometric.utils.convert import to_networkx\n", "\n", "from topomodelx.nn.hypergraph.unigin import UniGIN\n", "from topomodelx.utils.sparse import from_sparse\n", "\n", "torch.manual_seed(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If GPU's are available, we will make use of them. Otherwise, this will run on CPU." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Train a UNIGIN TNN" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Pre-processing\n", "\n", "## Import data ##\n", "\n", "The first step is to import the dataset, MUTAG, a benchmark dataset for graph classification. We then lift each graph into our domain of choice, a hypergraph.\n", "\n", "We will also retrieve:\n", "- input signal on the nodes for each of these hypergraphs, as that will be what we feed the model in input\n", "- the binary label associated to the hypergraph" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "dataset = TUDataset(root=\"/tmp/MUTAG\", name=\"MUTAG\", use_edge_attr=True)\n", "dataset = dataset[:100]\n", "hg_list = []\n", "x_1_list = []\n", "y_list = []\n", "for graph in dataset:\n", " hg = SimplicialComplex(to_networkx(graph)).to_hypergraph()\n", " hg_list.append(hg)\n", " x_1_list.append(graph.x.to(device))\n", " y_list.append(graph.y.to(device))\n", "\n", "incidence_1_list = []\n", "for hg in hg_list:\n", " incidence_1 = hg.incidence_matrix()\n", " incidence_1 = from_sparse(incidence_1)\n", " incidence_1_list.append(incidence_1.to(device))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Create the Neural Network\n", "\n", "Define the network that initializes the base model and sets up the readout operation.\n", "Different downstream tasks might require different pooling procedures.\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class Network(torch.nn.Module):\n", " \"\"\"Network class that initializes the base modelo and readout layer.\n", "\n", " Base model parameters:\n", " ----------\n", " Reqired:\n", " in_channels : int\n", " Dimension of the input features.\n", " hidden_channels : int\n", " Dimension of the hidden features.\n", "\n", " Optitional:\n", " **kwargs : dict\n", " Additional arguments for the base model.\n", "\n", " Readout layer parameters:\n", " ----------\n", " out_channels : int\n", " Dimension of the output features.\n", " task_level : str\n", " Level of the task. 
Either \"graph\" or \"node\".\n", " \"\"\"\n", "\n", " def __init__(\n", " self, in_channels, hidden_channels, out_channels, task_level=\"graph\", **kwargs\n", " ):\n", " super().__init__()\n", "\n", " # Define the model\n", " self.base_model = UniGIN(\n", " in_channels=in_channels, hidden_channels=hidden_channels, **kwargs\n", " )\n", "\n", " # Readout\n", " self.linear = torch.nn.Linear(hidden_channels, out_channels)\n", " self.out_pool = task_level == \"graph\"\n", "\n", " def forward(self, x_0, incidence_1):\n", " # Base model\n", " x_0, x_1 = self.base_model(x_0, incidence_1)\n", "\n", " # Pool over all nodes in the hypergraph\n", " x = torch.max(x_0, dim=0)[0] if self.out_pool is True else x_0\n", "\n", " return self.linear(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initialize the model" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Base model hyperparameters\n", "in_channels = x_1_list[0].shape[1]\n", "hidden_channels = 32\n", "n_layers = 3\n", "mlp_num_layers = 1\n", "input_drop = 0.2\n", "layer_drop = 0.2\n", "\n", "# Readout hyperparameters\n", "out_channels = 2\n", "task_level = \"graph\"\n", "\n", "\n", "model = Network(\n", " in_channels=in_channels,\n", " hidden_channels=hidden_channels,\n", " input_drop=input_drop,\n", " layer_drop=layer_drop,\n", " n_layers=n_layers,\n", " out_channels=out_channels,\n", " task_level=task_level,\n", ").to(device)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Train the Neural Network\n", "\n", "We specify the model, the loss, and an optimizer." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", "crit = torch.nn.CrossEntropyLoss()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Split the dataset into train, val and test sets." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "x_1_train, x_1_test = train_test_split(x_1_list, test_size=0.2, shuffle=False)\n", "incidence_1_train, incidence_1_test = train_test_split(\n", " incidence_1_list, test_size=0.2, shuffle=False\n", ")\n", "y_train, y_test = train_test_split(y_list, test_size=0.2, shuffle=False)\n", "\n", "x_1_train, x_1_val = train_test_split(x_1_train, test_size=0.2, shuffle=False)\n", "incidence_1_train, incidence_1_val = train_test_split(\n", " incidence_1_train, test_size=0.2, shuffle=False\n", ")\n", "y_train, y_val = train_test_split(y_train, test_size=0.2, shuffle=False)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "**Note: The number of epochs below have been kept low to facilitate debugging and testing. 
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "**Note: the number of epochs below is kept low to facilitate debugging and testing. Real use cases will likely require more epochs.**" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0 loss: 1261.591552734375\n", "Epoch 0 Validation accuracy: 0.5625\n", "Epoch 10 loss: 50.00477600097656\n", "Epoch 10 Validation accuracy: 0.5625\n", "Epoch 20 loss: 37.117366790771484\n", "Epoch 20 Validation accuracy: 0.4375\n", "Epoch 30 loss: 30.90342903137207\n", "Epoch 30 Validation accuracy: 0.3125\n", "Epoch 40 loss: 25.588314056396484\n", "Epoch 40 Validation accuracy: 0.5625\n", "Epoch 50 loss: 26.57889747619629\n", "Epoch 50 Validation accuracy: 0.625\n", "Epoch 60 loss: 19.412574768066406\n", "Epoch 60 Validation accuracy: 0.5625\n", "Test accuracy: 0.7\n" ] } ], "source": [ "test_interval = 10\n", "num_epochs = 70\n", "for epoch in range(num_epochs):\n", "    model.train()\n", "    optimizer.zero_grad()\n", "    loss = 0\n", "    for x_1, incidence_1, y in zip(x_1_train, incidence_1_train, y_train, strict=False):\n", "        output = model(x_1, incidence_1)\n", "        loss += crit(output.unsqueeze(0), y)\n", "    loss.backward()\n", "    optimizer.step()\n", "    if epoch % test_interval == 0:\n", "        print(f\"Epoch {epoch} loss: {loss.item()}\")\n", "        model.eval()\n", "        with torch.no_grad():\n", "            correct = 0\n", "            for x_1, incidence_1, y in zip(\n", "                x_1_val, incidence_1_val, y_val, strict=False\n", "            ):\n", "                output = model(x_1, incidence_1)\n", "                pred = torch.argmax(output)\n", "                if pred == y:\n", "                    correct += 1\n", "            print(f\"Epoch {epoch} Validation accuracy: {correct / len(y_val)}\")\n", "\n", "model.eval()\n", "with torch.no_grad():\n", "    correct = 0\n", "    for x_1, incidence_1, y in zip(x_1_test, incidence_1_test, y_test, strict=False):\n", "        output = model(x_1, incidence_1)\n", "        pred = torch.argmax(output)\n", "        if pred == y:\n", "            correct += 1\n", "    print(f\"Test accuracy: {correct / len(y_test)}\")" ] }
], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" } }, "nbformat": 4, "nbformat_minor": 4 }