{ "cells": [ { "cell_type": "markdown", "id": "b12508af-69b5-4663-b8c6-771853be4d15", "metadata": {}, "source": [ "# Tutorial: Cell2Vec - Embedding Cells using Topological Deep Learning\n", "\n", "\n", "In this tutorial, we explored Cell2Vec, a method for embedding cells in cell or simplicial complexes using topological deep learning. We implemented Cell2Vec using TopoEmbedX and demonstrated the visualization of cell embeddings.\n" ] }, { "cell_type": "markdown", "id": "ff32d87d-2411-4054-b6e8-f0774d7522b7", "metadata": {}, "source": [ "In this tutorial, we will explore Cell2Vec, a method for embedding cells using topological deep learning techniques. Node2Vec is a generalization of the DeepWalk algorithm to cell or simplicial complexes, enabling the generation of meaningful embeddings for cells of varying dimensions (0, 1, or 2) within a topological domain.\n", "\n", "### Table of Contents\n", "\n", "* Introduction\n", "* Understanding Node2Vec\n", "* Generalizing Node2Vec to Cell2Vec\n", "* Implementation in Python using TopoEmbedX\n", "* Visualization of Cell Embeddings\n", "* Conclusion and Further Steps" ] }, { "cell_type": "markdown", "id": "03c12f22-8c15-45dc-abaa-4d5956d61447", "metadata": {}, "source": [ "## Understanding Node2Vec\n", "Node2Vec is an algorithm used for generating embeddings in graph-structured data, such as social networks or citation networks. It extends the idea of Word2Vec to graphs, aiming to learn low-dimensional vector representations for nodes that preserve the network structure.\n", "\n", "The key concept in Node2Vec is the notion of biased random walks. Instead of performing simple random walks, Node2Vec introduces parameters to control the walk behavior. These parameters allow the random walker to balance between exploring the neighborhood locally (BFS-like) and jumping to far-away nodes (DFS-like). 
This biased random walk strategy helps capture the neighborhood structure of each node.\n", "\n", "## Generalizing Node2Vec to Cell2Vec\n", "Cell2Vec extends the principles of Node2Vec to cell or simplicial complexes: the cells of a chosen rank (vertices, edges, or faces) play the role of graph nodes, and the neighborhood relations between those cells (adjacency or coadjacency) play the role of graph edges. By running biased random walks over these relations, Cell2Vec captures the topological structure of the complex and produces meaningful embeddings for its cells.\n", "\n", "## Implementation in Python using TopoEmbedX\n", "To implement Cell2Vec, we use the TopoEmbedX library, which provides tools for topological representation learning. We demonstrate how to create a model, fit it to a cell complex, and obtain the cell embeddings." ] }, { "cell_type": "code", "execution_count": 69, "id": "b7f8e498-d9b1-4e71-bbcc-bbacd49ec65c", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "import toponetx as tnx\n", "\n", "import topoembedx as tex\n", "\n", "# Create a cell complex object with a few cells\n", "cc = tnx.classes.CellComplex([[1, 2, 3, 4, 5], [4, 5, 6]], ranks=2)\n", "\n", "# Create a model\n", "model = tex.Cell2Vec(dimensions=2)\n", "\n", "# Fit the model to the cell complex\n", "\n", "model.fit(cc, neighborhood_type=\"adj\", neighborhood_dim={\"rank\": 1, \"via_rank\": -1})\n", "\n", "# note that \"via_rank\" is ignored here and only considered when the complex is Combintatorial complex\n", "\n", "# Get the embeddings\n", "embedded_points = model.get_embedding(get_dict=True)\n", "\n", "\n", "# Prepare data for plotting\n", "x = [embedded_points[cell][0] for cell in embedded_points]\n", "y = [embedded_points[cell][1] for cell in embedded_points]\n", "cell_labels = [f\"Cell {cell}\" for cell in embedded_points]\n", "\n", "# Plotting\n", "plt.figure(figsize=(10, 8))\n", "plt.scatter(x, y, c=\"blue\", label=\"Projected Points\")\n", "\n", "# Annotate the points to correspond with cells\n", "for i, label in enumerate(cell_labels):\n", " plt.annotate(\n", " label, (x[i], y[i]), textcoords=\"offset points\", xytext=(0, 10), ha=\"center\"\n", " )\n", "\n", "# Label axes and add title\n", "plt.xlabel(\"X-axis\")\n", "plt.ylabel(\"Y-axis\")\n", "plt.title(\"Projection of Cell Complex in 2D\")\n", "\n", "# Display the plot\n", "plt.legend()\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "694be0ab-8bf6-4aad-b900-e1088dd3fbc9", "metadata": {}, "source": [ "Observe how the edge (4,5) is inbetweeen the edges belonging to the cell (1, 2, 3, 4, 5) whereas the edges (4,5) and (5,6) are isolated on the otherwise. 
This is because (4,5) is shared between the two cells.\n", "\n", "\n", "Next, we embed the 2-cells (faces) using Cell2Vec:" ] }, { "cell_type": "code", "execution_count": 86, "id": "447c3dfe-58f9-4ed5-bf40-d58ad65dd1ac", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "import toponetx as tnx\n", "\n", "import topoembedx as tex\n", "\n", "# Create a cell complex object with a few cells\n", "cc = tnx.classes.CellComplex(\n", " [[1, 2, 3], [1, 2, 3, 4, 5], [7, 8, 9, 12], [7, 8, 9, 10], [4, 5, 6]], ranks=2\n", ")\n", "\n", "# Create a model\n", "model = tex.Cell2Vec(dimensions=2)\n", "\n", "# Fit the model to the cell complex\n", "model.fit(cc, neighborhood_type=\"coadj\", neighborhood_dim={\"rank\": 2, \"via_rank\": -1})\n", "\n", "# Get the embeddings\n", "embedded_points = model.get_embedding(get_dict=True)\n", "\n", "\n", "# Prepare data for plotting\n", "x = [embedded_points[cell][0] for cell in embedded_points]\n", "y = [embedded_points[cell][1] for cell in embedded_points]\n", "cell_labels = [f\"Cell {cell}\" for cell in embedded_points]\n", "\n", "# Plotting\n", "plt.figure(figsize=(10, 8))\n", "plt.scatter(x, y, c=\"blue\", label=\"Projected Points\")\n", "\n", "# Annotate the points to correspond with cells\n", "for i, label in enumerate(cell_labels):\n", " plt.annotate(\n", " label, (x[i], y[i]), textcoords=\"offset points\", xytext=(0, 10), ha=\"center\"\n", " )\n", "\n", "# Label axes and add title\n", "plt.xlabel(\"X-axis\")\n", "plt.ylabel(\"Y-axis\")\n", "plt.title(\"Projection of Cell Complex in 2D\")\n", "\n", "# Display the plot\n", "plt.legend()\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "6f2ae646-e504-4055-af4c-865236c529cd", "metadata": {}, "source": [ "### Refs\n", "\n", "(1) Mustafa Hajij,Kyle Istvan,and Ghada Zamzmi, Cell Complex Neural Networks. NeurIPS2020 Workshop TDA and Beyond, 2020." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 5 }