Source code for topomodelx.nn.simplicial.scone

"""Neural network implementation of classification using SCoNe."""

import random
from itertools import product

import networkx as nx
import numpy as np
import toponetx as tnx
import torch
from scipy.spatial import Delaunay, distance
from torch import nn
from torch.utils.data.dataset import Dataset

from topomodelx.nn.simplicial.scone_layer import SCoNeLayer


[docs] def generate_complex( N: int = 100, *, rng: np.random.Generator | None = None ) -> tuple[tnx.SimplicialComplex, np.ndarray]: """Generate a simplicial complex as described. Generate a simplicial complex of dimension 2 as follows: 1. Uniformly sample N random points form the unit square and build the Delaunay triangulation. 2. Delete triangles contained in some pre-defined disks. Parameters ---------- N : int Number of vertices in the simplicial complex. rng : np.random.Generator, optional The random number generator to use, defaults to NumPy's default generator. """ if rng is None: rng = np.random.default_rng() points = rng.uniform(0, 1, size=(N, 2)) # Sort points by the sum of their coordinates c = np.sum(points, axis=1) order = np.argsort(c) points = points[order] tri = Delaunay(points) # Create Delaunay triangulation # Remove triangles having centroid inside the disks disk_centers = np.array([[0.3, 0.7], [0.7, 0.3]]) disk_radius = 0.15 simplices = [] indices_included = set() for simplex in tri.simplices: center = np.mean(points[simplex], axis=0) if ~np.any(distance.cdist([center], disk_centers) <= disk_radius, axis=1): # Centroid is not contained in some disk, so include it. simplices.append(simplex) indices_included |= set(simplex) # Re-index vertices before constructing the simplicial complex idx_dict = {i: j for j, i in enumerate(indices_included)} for i in range(len(simplices)): for j in range(3): simplices[i][j] = idx_dict[simplices[i][j]] sc = tnx.SimplicialComplex(simplices) coords = points[list(indices_included)] return sc, coords
[docs] def generate_trajectories( sc: tnx.SimplicialComplex, coords: np.ndarray, n_max: int = 1000 ) -> list[list[int]]: """Generate trajectories from nodes in the lower left corner to the upper right corner connected through a node in the middle.""" # Get indices for start points in the lower left corner, mid points in the center region and end points in the upper right corner. N = len(sc) start_nodes = list(range(int(0.2 * N))) mid_nodes = list(range(int(0.4 * N), int(0.5 * N))) end_nodes = list(range(int(0.8 * N), N)) all_triplets = list(product(start_nodes, mid_nodes, end_nodes)) assert ( len(all_triplets) >= n_max ), f"Only {len(all_triplets)} valid paths, but {n_max} requested. Try increasing the number of points in the simplicial complex." triplets = random.sample(all_triplets, n_max) # Compute pairwise distances and create a matrix representing the underlying graph. distance_matrix = distance.squareform(distance.pdist(coords)) graph = sc.adjacency_matrix(0).toarray() * distance_matrix G = nx.from_numpy_array(graph) # Find shortest paths trajectories = [] for s, m, e in triplets: path_1 = nx.shortest_path(G, s, m, weight="weight") path_2 = nx.shortest_path(G, m, e, weight="weight") trajectory = path_1[:-1] + path_2 trajectories.append(trajectory) return trajectories
[docs] class TrajectoriesDataset(Dataset): """Create a dataset of trajectories.""" def __init__( self, sc: tnx.SimplicialComplex, trajectories: list[list[int]] ) -> None: self.trajectories = trajectories self.sc = sc self.adjacency = torch.Tensor(sc.adjacency_matrix(0).toarray()) # Lookup table used to speed up vectorizing of trajectories self.edge_lookup_table = {} for i, edge in enumerate(self.sc.skeleton(1)): edge = tuple(edge) self.edge_lookup_table[edge] = (1, i) self.edge_lookup_table[edge[::-1]] = (-1, i) def __getitem__( self, index: int ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Get the index of the trajectory and its neighbors.""" trajectory = self.trajectories[index] vectorized_trajectory = self.vectorize_path( trajectory[:-1] ) # Discard the last node # Find neighbors of the last node in the trajectory (for use in the forward pass of SCoNe) neighbors_mask = ( torch.Tensor(self.adjacency[trajectory[-2]] > 0).float().unsqueeze(-1) ) last_node = torch.tensor(trajectory[-1]) return vectorized_trajectory, neighbors_mask, last_node def __len__(self) -> int: """Trajectories in the dataset.""" return len(self.trajectories)
[docs] def vectorize_path(self, path: list[int]) -> torch.Tensor: """Vectorize a path of nodes into a vector representation of the trajectory.""" # Create a vector representation of a trajectory. m = len(self.sc.skeleton(1)) c0 = torch.zeros((m, 1)) for j in range(len(path) - 1): edge = (path[j], path[j + 1]) sign, i = self.edge_lookup_table[edge] c0[i] = sign return c0
[docs] class SCoNe(nn.Module): """Neural network implementation of classification using SCoNe.""" def __init__(self, in_channels: int, hidden_channels: int, n_layers: int) -> None: super().__init__() self.in_channels = in_channels self.hidden_channels = hidden_channels self.n_layers = n_layers # Stack multiple SCoNe layers with given hidden dimensions self.layers = nn.ModuleList( [SCoNeLayer(self.in_channels, self.hidden_channels)] ) for _ in range(self.n_layers - 1): self.layers.append(SCoNeLayer(self.hidden_channels, self.hidden_channels)) # Initialize parameters for layer in self.layers: layer.reset_parameters()
[docs] def forward( self, x: torch.Tensor, incidence_1: torch.Tensor, incidence_2: torch.Tensor ) -> torch.Tensor: """Forward pass through the network.""" for layer in self.layers: x = layer(x, incidence_1, incidence_2) return x