# Source code for topomodelx.nn.simplicial.san_layer

"""Simplicial Attention Network (SAN) Layer."""
from typing import Literal

import torch
from torch.nn.parameter import Parameter

from topomodelx.base.conv import Conv

class SANConv(Conv):
    r"""Simplicial Attention Network (SAN) Convolution from [1]_.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    n_filters : int
        Number of simplicial filters, i.e. the highest power of the attentional
        Laplacian used by the convolution. Must be at least 1.
    initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
        Weight initialization method.

    Raises
    ------
    ValueError
        If ``n_filters`` is smaller than 1.

    References
    ----------
    .. [1] Giusti, Battiloro, Di Lorenzo, Sardellitti and Barbarossa.
        Simplicial attention neural networks (2022).
    .. [2] Papillon, Sanborn, Hajij, Miolane.
        Equations of topological neural networks (2023).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        n_filters,
        initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
    ) -> None:
        # Without at least one filter the weight bank is empty and ``forward``
        # would fail later with an opaque reshape error; fail early instead.
        if n_filters < 1:
            raise ValueError(f"n_filters must be at least 1, got {n_filters}.")

        # Deliberately calls the initializer of ``Conv``'s *parent* class,
        # skipping ``Conv.__init__`` itself (presumably because SAN allocates
        # its own filter-bank weight below), with attention switched on.
        super(Conv, self).__init__(
            att=True,
            initialization=initialization,
        )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_filters = n_filters
        self.initialization = initialization

        # One (in_channels, out_channels) weight matrix per filter order.
        self.weight = Parameter(
            torch.empty(self.n_filters, self.in_channels, self.out_channels)
        )
        # Attention vector applied to the concatenated (source, target)
        # messages, each of size out_channels * n_filters.
        self.att_weight = Parameter(
            torch.empty(2 * self.out_channels * self.n_filters)
        )
        # Parameters above are allocated uninitialised; fill them now.
        self.reset_parameters()

    def forward(self, x_source, neighborhood):
        r"""Forward pass.

        This implements message passing:
        - from source cells with input features `x_source`,
        - via `neighborhood` defining where messages can pass,
        - to target cells, which are the source cells themselves.

        In practice, this will update the features on the target cells [2]_.

        .. math::
            \begin{align*}
            &🟥 \quad m_{y \rightarrow \{z\} \rightarrow x}^{u,(1 \rightarrow 2 \rightarrow 1)} = ((L_{\uparrow,1} \odot \operatorname{att}(h_z^{t,(2)}, h_y^{t,(1)}))^u)\_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,u}\\
            &🟥 \quad m_{y \rightarrow \{z\} \rightarrow x}^{d,(1 \rightarrow 0 \rightarrow 1)} = ((L_{\downarrow,1} \odot \operatorname{att}(h_z^{t,(0)}, h_y^{t,(1)}))^d)\_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,d}\\
            &🟥 \quad m^{p,(1 \rightarrow 1)}\_{y \rightarrow x} = ((1-wH_1)^p)\_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,p}\\
            &🟧 \quad m_{x}^{u,(1 \rightarrow 2 \rightarrow 1)} = \sum_{y \in \mathcal{L}\_\uparrow(x)} m_{y \rightarrow \{z\} \rightarrow x}^{u,(1 \rightarrow 2 \rightarrow 1)}\\
            &🟧 \quad m_{x}^{d,(1 \rightarrow 0 \rightarrow 1)} = \sum_{y \in \mathcal{L}\_\downarrow(x)} m_{y \rightarrow \{z\} \rightarrow x}^{d,(1 \rightarrow 0 \rightarrow 1)}\\
            &🟧 \quad m^{p,(1 \rightarrow 1)}\_{x} = m^{p,(1 \rightarrow 1)}\_{x \rightarrow x}\\
            &🟩 \quad m_x^{(1)} = \sum_{p=1}^P m_x^{p,(1 \rightarrow 1)} + \sum_{u=1}^{U} m_{x}^{u,(1 \rightarrow 2 \rightarrow 1)} + \sum_{d=1}^{D} m_{x}^{d,(1 \rightarrow 0 \rightarrow 1)}\\
            &🟦 \quad h_x^{t+1, (1)} = \sigma(m_x^{(1)})
            \end{align*}

        Parameters
        ----------
        x_source : torch.Tensor, shape = (n_source_cells, in_channels)
            Input features on source cells. Assumes that all source cells have
            the same rank r. Batched inputs are not supported: the filter-bank
            product is followed by a 3-D ``permute``, which requires a 2-D input.
        neighborhood : torch.sparse, shape = (n_source_cells, n_source_cells)
            Square neighborhood matrix (an up- or down-Laplacian in SAN). Only
            its sparsity pattern is used, as a mask selecting which attention
            coefficients are computed; its values are ignored.

        Returns
        -------
        torch.Tensor, shape = (n_source_cells, out_channels)
            Output features on the target cells (the source cells themselves).
        """
        # Apply every filter weight at once:
        # (n, in_channels) @ (P, in_channels, out_channels) -> (P, n, out_channels).
        x_message = torch.matmul(x_source, self.weight)
        # Reshape to [n_cells, out_channels * n_filters] so the attention
        # function inherited from the parent class can be reused as-is.
        x_message_reshaped = x_message.permute(1, 0, 2).reshape(
            -1, self.out_channels * self.n_filters
        )

        # SAN always uses attention. The neighborhood (lower/upper Laplacian)
        # only acts as a mask: coefficients are computed for its nonzero
        # entries. The index attributes are stored on the module because the
        # inherited ``attention`` method is expected to read them
        # (NOTE(review): assumption about the parent class).
        neighborhood = neighborhood.coalesce()
        self.target_index_i, self.source_index_j = neighborhood.indices()
        attention_values = self.attention(x_message_reshaped)
        att_laplacian = torch.sparse_coo_tensor(
            indices=neighborhood.indices(),
            values=attention_values,
            size=neighborhood.shape,
        )
        # Row-wise softmax over the nonzero entries normalises the coefficients.
        att_laplacian = torch.sparse.softmax(att_laplacian, dim=1).to_dense()

        # Powers A, A^2, ..., A^P of the attentional Laplacian, one per filter.
        att_laplacian_powers = [att_laplacian]
        for _ in range(1, self.n_filters):
            att_laplacian_powers.append(
                torch.matmul(att_laplacian_powers[-1], att_laplacian)
            )
        stacked_powers = torch.stack(att_laplacian_powers)  # (P, n, n)

        # Filter p applies A^(p+1) to the messages produced by its own weight;
        # the per-filter results are summed into the final output.
        return torch.matmul(stacked_powers, x_message).sum(dim=0)
class SANLayer(torch.nn.Module):
    r"""Implementation of the Simplicial Attention Network (SAN) Layer proposed in [1]_.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    n_filters : int, default = 2
        Approximation order, i.e. the number of attentional-Laplacian powers
        used by each of the up and down convolutions.

    Notes
    -----
    Architecture proposed for r-simplex (r>0) classification on simplicial complexes.

    References
    ----------
    .. [1] Giusti, Battiloro, Di Lorenzo, Sardellitti and Barbarossa.
        Simplicial attention neural networks (2022).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        n_filters: int = 2,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_filters = n_filters

        # Down and up convolutions: a single SANConv per direction covers all
        # filter orders 1..n_filters internally (one weight per order).
        self.conv_down = SANConv(in_channels, out_channels, n_filters)
        self.conv_up = SANConv(in_channels, out_channels, n_filters)
        # Harmonic convolution: a plain convolution without attention, whose
        # neighborhood is the precomputed projection matrix.
        self.conv_harmonic = Conv(in_channels, out_channels)

    def reset_parameters(self) -> None:
        r"""Reset learnable parameters of all three convolutions."""
        self.conv_down.reset_parameters()
        self.conv_up.reset_parameters()
        self.conv_harmonic.reset_parameters()

    def forward(self, x, laplacian_up, laplacian_down, projection_mat):
        r"""Forward pass of the SAN Layer.

        .. math::
            \mathcal N = \{\mathcal N_1, \mathcal N_2,...,\mathcal N_{2p+1}\} = \{A_{\uparrow, r}, A_{\downarrow, r}, A_{\uparrow, r}^2, A_{\downarrow, r}^2,...,A_{\uparrow, r}^p, A_{\downarrow, r}^p, Q_r\},

        .. math::
            \begin{align*}
            &🟥\quad m_{(y \rightarrow x),k}^{(r)} = \alpha_k(h_x^t,h_y^t) = a_k(h_x^{t}, h_y^{t}) \cdot \psi_k^t(h_x^{t})\quad \forall \mathcal N_k \in \mathcal{N}\\
            &🟧\quad m_{x,k}^{(r)} = \bigoplus_{y \in \mathcal{N}_k(x)} m^{(r)}_{(y \rightarrow x),k}\\
            &🟩\quad m_{x}^{(r)} = \bigotimes_{\mathcal{N}_k\in\mathcal N}m_{x,k}^{(r)}\\
            &🟦\quad h_x^{t+1,(r)} = \phi^{t}(h_x^t, m_{x}^{(r)})
            \end{align*}

        Parameters
        ----------
        x : torch.Tensor, shape = (n_cells, in_channels)
            Input tensor. Must be 2-D; the underlying SANConv does not accept
            batched inputs.
        laplacian_up : torch.Tensor
        laplacian_down : torch.Tensor
            The up- and down-laplacians of the simplicial complex, as sparse
            (n_cells, n_cells) matrices; only their sparsity patterns are used.
        projection_mat : torch.Tensor
            The precomputed projection matrix used by the harmonic convolution.

        Returns
        -------
        torch.Tensor, shape = (n_cells, out_channels)
            Sum of the down, up and harmonic convolution outputs. No
            nonlinearity is applied here.
        """
        # Attentional convolutions along the down- and up-adjacencies.
        z_down = self.conv_down(x, laplacian_down)
        z_up = self.conv_up(x, laplacian_up)
        # Harmonic component: the projection matrix P acts as the
        # neighborhood, without attention.
        z_har = self.conv_harmonic(x, projection_mat)

        # Final output: plain sum of the three components.
        return z_down + z_up + z_har