Simplicial Attention Network (SAN) Layer.

class topomodelx.nn.simplicial.san_layer.SANConv(in_channels, out_channels, n_filters, initialization: Literal['xavier_uniform', 'xavier_normal'] = 'xavier_uniform')[source]#

Simplicial Attention Network (SAN) Convolution from [1].

Parameters:
in_channels : int

Number of input channels.

out_channels : int

Number of output channels.

n_filters : int

Number of simplicial filters.

initialization : Literal["xavier_uniform", "xavier_normal"], default = "xavier_uniform"

Weight initialization method.

References

[1]

Giusti, Battiloro, Di Lorenzo, Sardellitti and Barbarossa. Simplicial attention neural networks (2022). https://arxiv.org/abs/2203.07485.

[2]

Papillon, Sanborn, Hajij, Miolane. Equations of topological neural networks (2023). https://github.com/awesome-tnns/awesome-tnns
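
A minimal construction sketch, assuming only the signature documented above (the channel sizes are arbitrary):

>>> from topomodelx.nn.simplicial.san_layer import SANConv
>>> # 16-channel inputs, 32-channel outputs, 2 simplicial filters
>>> conv = SANConv(in_channels=16, out_channels=32, n_filters=2)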

forward(x_source, neighborhood)[source]#

Forward pass.

This implements message passing:

- from source cells with input features x_source,
- via a neighborhood defining where messages can pass,
- to target cells, which here coincide with the source cells.

In practice, this will update the features on the target cells [2].

\[\begin{split}\begin{align*}
&🟥 \quad m_{y \rightarrow \{z\} \rightarrow x}^{u,(1 \rightarrow 2 \rightarrow 1)} = ((L_{\uparrow,1} \odot \operatorname{att}(h_z^{t,(2)}, h_y^{t,(1)}))^u)_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,u}\\
&🟥 \quad m_{y \rightarrow \{z\} \rightarrow x}^{d,(1 \rightarrow 0 \rightarrow 1)} = ((L_{\downarrow,1} \odot \operatorname{att}(h_z^{t,(0)}, h_y^{t,(1)}))^d)_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,d}\\
&🟥 \quad m^{p,(1 \rightarrow 1)}_{y \rightarrow x} = ((1 - wH_1)^p)_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,p}\\
&🟧 \quad m_{x}^{u,(1 \rightarrow 2 \rightarrow 1)} = \sum_{y \in \mathcal{L}_\uparrow(x)} m_{y \rightarrow \{z\} \rightarrow x}^{u,(1 \rightarrow 2 \rightarrow 1)}\\
&🟧 \quad m_{x}^{d,(1 \rightarrow 0 \rightarrow 1)} = \sum_{y \in \mathcal{L}_\downarrow(x)} m_{y \rightarrow \{z\} \rightarrow x}^{d,(1 \rightarrow 0 \rightarrow 1)}\\
&🟧 \quad m^{p,(1 \rightarrow 1)}_{x} = m^{p,(1 \rightarrow 1)}_{x \rightarrow x}\\
&🟩 \quad m_x^{(1)} = \sum_{p=1}^{P} m_x^{p,(1 \rightarrow 1)} + \sum_{u=1}^{U} m_{x}^{u,(1 \rightarrow 2 \rightarrow 1)} + \sum_{d=1}^{D} m_{x}^{d,(1 \rightarrow 0 \rightarrow 1)}\\
&🟦 \quad h_x^{t+1,(1)} = \sigma(m_x^{(1)})
\end{align*}\end{split}\]
Parameters:
x_source : torch.Tensor, shape = (…, n_source_cells, in_channels)

Input features on source cells. Assumes that all source cells have the same rank r.

neighborhood : torch.sparse, shape = (n_target_cells, n_source_cells)

Neighborhood matrix.

Returns:
torch.Tensor, shape = (…, n_target_cells, out_channels)

Output features on target cells. Assumes that all target cells have the same rank s.
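
A shape-level usage sketch, assuming SANConv follows the standard torch.nn.Module calling convention; the sparse identity neighborhood is only a placeholder in which each cell attends to itself:

>>> import torch
>>> from topomodelx.nn.simplicial.san_layer import SANConv
>>> conv = SANConv(in_channels=16, out_channels=32, n_filters=2)
>>> x_source = torch.randn(30, 16)  # features on 30 source cells of rank r
>>> neighborhood = torch.eye(30).to_sparse()  # placeholder (n_target, n_source) matrix
>>> x_target = conv(x_source, neighborhood)  # expected shape (30, 32) per Returns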

class topomodelx.nn.simplicial.san_layer.SANLayer(in_channels, out_channels, n_filters: int = 2)[source]#

Implementation of the Simplicial Attention Network (SAN) Layer proposed in [1].

Parameters:
in_channels : int

Number of input channels.

out_channels : int

Number of output channels.

n_filters : int, default = 2

Approximation order p of the neighborhood set used in the forward pass.

Notes

Architecture proposed for r-simplex (r > 0) classification on simplicial complexes.

forward(x, laplacian_up, laplacian_down, projection_mat)[source]#

Forward pass of the SAN Layer.

\[\mathcal{N} = \{\mathcal{N}_1, \mathcal{N}_2, \ldots, \mathcal{N}_{2p+1}\} = \{A_{\uparrow, r}, A_{\downarrow, r}, A_{\uparrow, r}^2, A_{\downarrow, r}^2, \ldots, A_{\uparrow, r}^p, A_{\downarrow, r}^p, Q_r\},\]
\[\begin{split}\begin{align*}
&🟥 \quad m_{(y \rightarrow x),k}^{(r)} = \alpha_k(h_x^t, h_y^t) = a_k(h_x^{t}, h_y^{t}) \cdot \psi_k^t(h_x^{t}) \quad \forall \mathcal{N}_k \in \mathcal{N}\\
&🟧 \quad m_{x,k}^{(r)} = \bigoplus_{y \in \mathcal{N}_k(x)} m^{(r)}_{(y \rightarrow x),k}\\
&🟩 \quad m_{x}^{(r)} = \bigotimes_{\mathcal{N}_k \in \mathcal{N}} m_{x,k}^{(r)}\\
&🟦 \quad h_x^{t+1,(r)} = \phi^{t}(h_x^t, m_{x}^{(r)})
\end{align*}\end{split}\]
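An illustrative helper (not part of TopoModelX) showing how the 2p + 1 neighborhoods above could be assembled from dense operator tensors:

>>> import torch
>>> def neighborhood_set(a_up, a_down, q_r, p):
...     # {A_up^k, A_down^k : k = 1..p} followed by the projection Q_r
...     hoods = []
...     for k in range(1, p + 1):
...         hoods.append(torch.linalg.matrix_power(a_up, k))
...         hoods.append(torch.linalg.matrix_power(a_down, k))
...     return hoods + [q_r]
>>> len(neighborhood_set(torch.eye(4), torch.eye(4), torch.eye(4), p=2))  # 2p + 1
5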
Parameters:
x : torch.Tensor, shape = (…, n_cells, in_channels)

Input tensor.

laplacian_up : torch.Tensor
laplacian_down : torch.Tensor

The up- and down-laplacians of the simplicial complex.

projection_mat : torch.Tensor

The projection matrix Q_r appearing in the neighborhood set above.

Returns:
torch.Tensor, shape = (…, n_cells, out_channels)

Output tensor.
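
A usage sketch with identity placeholders for the Laplacians and projection matrix; in practice these operators come from a simplicial complex (e.g. built with TopoNetX) and may be sparse:

>>> import torch
>>> from topomodelx.nn.simplicial.san_layer import SANLayer
>>> layer = SANLayer(in_channels=16, out_channels=32, n_filters=2)
>>> x = torch.randn(30, 16)  # features on 30 r-cells
>>> eye = torch.eye(30)  # identity stand-ins for L_up, L_down, Q_r
>>> x_out = layer(x, laplacian_up=eye, laplacian_down=eye, projection_mat=eye)  # expected shape (30, 32) per Returns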

reset_parameters() → None[source]#

Reset learnable parameters.