Source code for topomodelx.nn.hypergraph.hnhn_layer

"""Template Layer with two conv passing steps."""
from typing import Literal

import torch
from torch.nn.parameter import Parameter

from topomodelx.base.conv import Conv


class HNHNLayer(torch.nn.Module):
    """Layer of a Hypergraph Network with Hyperedge Neurons (HNHN).

    Implementation of a simplified version of the HNHN layer proposed in [1]_.

    This layer is composed of two convolutional layers:

    1. A convolutional layer sending messages from edges to nodes.
    2. A convolutional layer sending messages from nodes to edges.

    The incidence matrices can be normalized using the node and edge
    cardinalities. Two hyperparameters, alpha and beta, control the
    normalization strength. The convolutional layers support the training
    of a bias term.

    Parameters
    ----------
    in_channels : int
        Dimension of node features.
    hidden_channels : int
        Dimension of hidden features.
    incidence_1 : torch.sparse, shape = (n_nodes, n_edges)
        Incidence matrix mapping edges to nodes (B_1).
    use_bias : bool
        Flag controlling whether to use a bias term in the convolution.
    use_normalized_incidence : bool
        Flag controlling whether to normalize the incidence matrices.
    alpha : float
        Scalar controlling the importance of edge cardinality.
    beta : float
        Scalar controlling the importance of node cardinality.
    bias_gain : float
        Gain for the bias initialization.
    bias_init : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
        Controls the bias initialization method.
    **kwargs : optional
        Additional arguments for the layer modules.

    Notes
    -----
    This is the architecture proposed for node classification.

    References
    ----------
    .. [1] Dong, Sawin, Bengio.
        HNHN: hypergraph networks with hyperedge neurons.
        Graph Representation Learning and Beyond Workshop at ICML 2020.
        https://grlplus.github.io/papers/40.pdf
    .. [2] Papillon, Sanborn, Hajij, Miolane.
        Equations of topological neural networks (2023).
        https://github.com/awesome-tnns/awesome-tnns/
    .. [3] Papillon, Sanborn, Hajij, Miolane.
        Architectures of topological deep learning: a survey on topological
        neural networks (2023).
        https://arxiv.org/abs/2304.10031.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        incidence_1=None,
        use_bias: bool = True,
        use_normalized_incidence: bool = True,
        alpha: float = -1.5,
        beta: float = -0.5,
        bias_gain: float = 1.414,
        bias_init: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
        **kwargs,
    ) -> None:
        super().__init__()
        self.use_bias = use_bias
        self.bias_init = bias_init
        self.bias_gain = bias_gain
        self.use_normalized_incidence = use_normalized_incidence
        self.incidence_1 = incidence_1
        if incidence_1 is not None:
            self.incidence_1_transpose = incidence_1.transpose(1, 0)

        self.conv_0_to_1 = Conv(
            in_channels=in_channels,
            out_channels=hidden_channels,
            aggr_norm=False,
            update_func=None,
        )
        self.conv_1_to_0 = Conv(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            aggr_norm=False,
            update_func=None,
        )
        if self.use_bias:
            self.bias_1_to_0 = Parameter(torch.Tensor(1, hidden_channels))
            self.bias_0_to_1 = Parameter(torch.Tensor(1, hidden_channels))
            self.init_biases()
        if self.use_normalized_incidence:
            self.alpha = alpha
            self.beta = beta
            if incidence_1 is not None:
                self.n_nodes, self.n_edges = self.incidence_1.shape
                self.compute_normalization_matrices()
                self.normalize_incidence_matrices()
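For orientation, here is a minimal construction sketch. The toy incidence matrix, channel sizes, and variable names below are illustrative assumptions, not part of the module:

    import torch

    from topomodelx.nn.hypergraph.hnhn_layer import HNHNLayer

    # Toy hypergraph (assumed for illustration): 4 nodes, 2 hyperedges,
    # so B_1 has shape (n_nodes, n_edges) = (4, 2).
    incidence_1 = torch.tensor(
        [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.0, 1.0]]
    ).to_sparse()

    # With use_normalized_incidence=True (the default), the constructor also
    # computes and applies the alpha/beta normalization shown in the methods below.
    layer = HNHNLayer(in_channels=8, hidden_channels=16, incidence_1=incidence_1)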
    def compute_normalization_matrices(self) -> None:
        """Compute the normalization matrices for the incidence matrices."""
        B1 = self.incidence_1.to_dense()
        edge_cardinality = B1.sum(0) ** self.alpha
        node_cardinality = B1.sum(1) ** self.beta

        # Compute D0_left_alpha_inverse
        self.D0_left_alpha_inverse = torch.zeros(self.n_nodes, self.n_nodes)
        for i_node in range(self.n_nodes):
            self.D0_left_alpha_inverse[i_node, i_node] = 1 / (
                edge_cardinality[B1[i_node, :].bool()].sum()
            )

        # Compute D1_left_beta_inverse
        self.D1_left_beta_inverse = torch.zeros(self.n_edges, self.n_edges)
        for i_edge in range(self.n_edges):
            self.D1_left_beta_inverse[i_edge, i_edge] = 1 / (
                node_cardinality[B1[:, i_edge].bool()].sum()
            )

        # Compute D1_right_alpha
        self.D1_right_alpha = torch.diag(edge_cardinality)

        # Compute D0_right_beta
        self.D0_right_beta = torch.diag(node_cardinality)
    def normalize_incidence_matrices(self) -> None:
        """Normalize the incidence matrices."""
        self.incidence_1 = (
            self.D0_left_alpha_inverse
            @ self.incidence_1.to_dense()
            @ self.D1_right_alpha
        ).to_sparse()
        self.incidence_1_transpose = (
            self.D1_left_beta_inverse
            @ self.incidence_1_transpose.to_dense()
            @ self.D0_right_beta
        ).to_sparse()
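In matrix form, the two methods above compute B_1 <- D_0(alpha)^{-1} B_1 D_1(alpha) and B_1^T <- D_1(beta)^{-1} B_1^T D_0(beta), where the diagonal D matrices are built from edge and node cardinalities raised to alpha and beta. A standalone sketch of the B_1 half of this computation in plain torch, on an assumed toy incidence matrix:

    import torch

    alpha = -1.5  # default edge-cardinality exponent
    B1 = torch.tensor([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.0, 1.0]])

    edge_cardinality = B1.sum(0) ** alpha  # |e|^alpha, one entry per hyperedge

    # Left normalizer: 1 / (sum of |e|^alpha over the hyperedges incident to each node).
    d0_left_inv = torch.diag(
        torch.stack(
            [1 / edge_cardinality[B1[i].bool()].sum() for i in range(B1.shape[0])]
        )
    )
    d1_right = torch.diag(edge_cardinality)

    # After normalization, each row of B1_normalized sums to 1.
    B1_normalized = d0_left_inv @ B1 @ d1_right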
    def init_biases(self) -> None:
        """Initialize the bias."""
        for bias in [self.bias_0_to_1, self.bias_1_to_0]:
            if self.bias_init == "xavier_uniform":
                torch.nn.init.xavier_uniform_(bias, gain=self.bias_gain)
            elif self.bias_init == "xavier_normal":
                torch.nn.init.xavier_normal_(bias, gain=self.bias_gain)
    def reset_parameters(self) -> None:
        """Reset learnable parameters."""
        self.conv_1_to_0.reset_parameters()
        self.conv_0_to_1.reset_parameters()
        if self.use_bias:
            self.init_biases()
    def forward(self, x_0, incidence_1=None):
        r"""Forward computation.

        The forward pass was initially proposed in [1]_.
        Its equations are given in [2]_ and graphically illustrated in [3]_.

        The equations of one layer of this neural network are given by:

        .. math::
            \begin{align*}
            &🟥 \quad m_{y \rightarrow x}^{(0 \rightarrow 1)}
                = \sigma((B_1^T \cdot W^{(0)})_{xy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(0)} + b^{t,(0)})\\
            &🟥 \quad m_{y \rightarrow x}^{(1 \rightarrow 0)}
                = \sigma((B_1 \cdot W^{(1)})_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,(1)} + b^{t,(1)})\\
            &🟧 \quad m_x^{(0 \rightarrow 1)}
                = \sum_{y \in \mathcal{B}(x)} m_{y \rightarrow x}^{(0 \rightarrow 1)}\\
            &🟧 \quad m_x^{(1 \rightarrow 0)}
                = \sum_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(1 \rightarrow 0)}\\
            &🟩 \quad m_x^{(0)} = m_x^{(1 \rightarrow 0)}\\
            &🟩 \quad m_x^{(1)} = m_x^{(0 \rightarrow 1)}\\
            &🟦 \quad h_x^{t+1,(0)} = m_x^{(0)}\\
            &🟦 \quad h_x^{t+1,(1)} = m_x^{(1)}
            \end{align*}

        Parameters
        ----------
        x_0 : torch.Tensor, shape = (n_nodes, channels_node)
            Input features on the hypernodes.
        incidence_1 : torch.Tensor, shape = (n_nodes, n_edges)
            Incidence matrix mapping edges to nodes (B_1).

        Returns
        -------
        x_0 : torch.Tensor, shape = (n_nodes, channels_node)
            Output features on the hypernodes.
        x_1 : torch.Tensor, shape = (n_edges, channels_edge)
            Output features on the hyperedges.
        """
        if incidence_1 is not None:
            self.incidence_1 = incidence_1
            self.incidence_1_transpose = incidence_1.transpose(1, 0)
            if self.use_normalized_incidence:
                self.n_nodes, self.n_edges = incidence_1.shape
                self.compute_normalization_matrices()
                self.normalize_incidence_matrices()
        # Move incidence matrices to the device of the input features
        self.incidence_1 = self.incidence_1.to(x_0.device)
        self.incidence_1_transpose = self.incidence_1_transpose.to(x_0.device)
        # Compute output hyperedge features (nodes to edges)
        x_1 = self.conv_0_to_1(x_0, self.incidence_1_transpose)
        if self.use_bias:
            x_1 += self.bias_0_to_1
        # Compute output hypernode features (edges to nodes)
        x_0 = self.conv_1_to_0(x_1, self.incidence_1)
        if self.use_bias:
            x_0 += self.bias_1_to_0
        return torch.relu(x_0), torch.relu(x_1)
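A forward-pass sketch under the same toy assumptions as the construction example (shapes and names are illustrative):

    import torch

    from topomodelx.nn.hypergraph.hnhn_layer import HNHNLayer

    # Assumed toy hypergraph: 4 nodes, 2 hyperedges.
    incidence_1 = torch.tensor(
        [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.0, 1.0]]
    ).to_sparse()
    layer = HNHNLayer(in_channels=8, hidden_channels=16, incidence_1=incidence_1)

    x_0 = torch.randn(4, 8)  # features on the 4 hypernodes
    x_0_out, x_1_out = layer(x_0)
    print(x_0_out.shape)  # torch.Size([4, 16]): updated hypernode features
    print(x_1_out.shape)  # torch.Size([2, 16]): updated hyperedge features

Both outputs pass through a ReLU, matching the last line of forward.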