Source code for topobenchmarkx.transforms.liftings.graph2cell.cycle

"""This module implements the cycle lifting for graphs to cell complexes."""

import networkx as nx
import torch_geometric
from toponetx.classes import CellComplex

from topobenchmarkx.transforms.liftings.graph2cell.base import (
    Graph2CellLifting,
)


class CellCycleLifting(Graph2CellLifting):
    r"""Lift a graph to a cell complex by promoting its cycles to 2-cells.

    A cycle basis of the input graph is computed and each retained cycle
    is attached to the complex as a cell of rank ``complex_dim`` (2).

    Parameters
    ----------
    max_cell_length : int, optional
        Upper bound (inclusive) on the number of nodes a cycle may have
        in order to be lifted. ``None`` (the default) disables the bound.
    **kwargs : optional
        Extra keyword arguments forwarded to the parent class.
    """

    def __init__(self, max_cell_length=None, **kwargs):
        super().__init__(**kwargs)
        # Cycles are lifted to cells of rank 2.
        self.complex_dim = 2
        self.max_cell_length = max_cell_length

    def lift_topology(self, data: torch_geometric.data.Data) -> dict:
        r"""Lift the cycles of the input graph to 2-cells.

        Parameters
        ----------
        data : torch_geometric.data.Data
            The graph data whose topology is lifted.

        Returns
        -------
        dict
            The lifted topology, as produced by ``_get_lifted_topology``.
        """
        graph = self._generate_graph_from_data(data)
        basis = nx.cycle_basis(graph)
        lifted = CellComplex(graph)

        length_cap = self.max_cell_length

        def _keep(candidate):
            size = len(candidate)
            # Length-1 cycles are self-loops and are never lifted.
            if size == 1:
                return False
            # Enforce the optional upper bound on the cycle length.
            return length_cap is None or size <= length_cap

        selected = [candidate for candidate in basis if _keep(candidate)]

        # Only touch the complex when at least one cycle survived filtering.
        if selected:
            lifted.add_cells_from(selected, rank=self.complex_dim)

        return self._get_lifted_topology(lifted, graph)