# Source code for topomodelx.nn.simplicial.san
"""Simplicial Attention Network (SAN) implementation for binary edge classification."""
import torch
from topomodelx.nn.simplicial.san_layer import SANLayer
[docs]
class SAN(torch.nn.Module):
    """Simplicial Attention Network (SAN) implementation for binary edge classification.

    The network stacks ``n_layers`` SAN layers that operate on edge
    (rank-1 simplex) features and returns the final edge representations;
    any classification head is applied by the caller.

    Parameters
    ----------
    in_channels : int
        Dimension of input features.
    hidden_channels : int
        Dimension of hidden features.
    out_channels : int, optional
        Dimension of output features. If None (the default),
        ``hidden_channels`` is used.
    n_filters : int, default = 2
        Approximation order for simplicial filters.
    order_harmonic : int, default = 5
        Approximation order for harmonic convolution, i.e. the matrix power
        applied to ``I - epsilon_harmonic * L`` in
        ``compute_projection_matrix``.
    epsilon_harmonic : float, default = 1e-1
        Epsilon value for harmonic convolution.
    n_layers : int, default = 2
        Number of message passing layers. Must be at least 1.

    Raises
    ------
    ValueError
        If ``n_layers`` is smaller than 1.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels=None,
        n_filters=2,
        order_harmonic=5,
        epsilon_harmonic=1e-1,
        n_layers=2,
    ):
        # Reject non-positive depths up front instead of silently building a
        # network whose depth does not match the request.
        if n_layers < 1:
            raise ValueError(f"n_layers must be at least 1, got {n_layers}.")
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = (
            out_channels if out_channels is not None else hidden_channels
        )
        self.n_filters = n_filters
        self.order_harmonic = order_harmonic
        self.epsilon_harmonic = epsilon_harmonic
        self.n_layers = n_layers

        # Channel size at every layer boundary:
        # in -> hidden -> ... -> hidden -> out (just in -> out for one layer).
        channels = (
            [self.in_channels]
            + [self.hidden_channels] * (n_layers - 1)
            + [self.out_channels]
        )
        self.layers = torch.nn.ModuleList(
            SANLayer(
                in_channels=layer_in,
                out_channels=layer_out,
                n_filters=self.n_filters,
            )
            for layer_in, layer_out in zip(channels[:-1], channels[1:])
        )

    def compute_projection_matrix(self, laplacian):
        """Compute the projection matrix.

        The matrix is used to calculate the harmonic component in SAN layers.
        It is the matrix power
        ``(I - epsilon_harmonic * laplacian) ** order_harmonic``.

        Parameters
        ----------
        laplacian : torch.Tensor, shape = (n_edges, n_edges)
            Hodge laplacian of rank 1.

        Returns
        -------
        torch.Tensor, shape = (n_edges, n_edges)
            Projection matrix.
        """
        # Allocate the identity directly on the Laplacian's device rather than
        # creating it on the CPU and copying it over.
        eye = torch.eye(laplacian.shape[0], device=laplacian.device)
        projection_mat = eye - self.epsilon_harmonic * laplacian
        return torch.linalg.matrix_power(projection_mat, self.order_harmonic)

    def forward(self, x, laplacian_up, laplacian_down):
        """Forward computation.

        Parameters
        ----------
        x : torch.Tensor, shape = (n_edges, in_channels)
            Edge features, one row per edge (the same edges that index the
            laplacians).
        laplacian_up : torch.Tensor, shape = (n_edges, n_edges)
            Upper laplacian of rank 1.
        laplacian_down : torch.Tensor, shape = (n_edges, n_edges)
            Down laplacian of rank 1.

        Returns
        -------
        torch.Tensor, shape = (n_edges, out_channels)
            Final hidden representations of edges.
        """
        # The rank-1 Hodge laplacian is the sum of its up and down parts; the
        # harmonic projection is computed once and shared by every layer.
        laplacian = laplacian_up + laplacian_down
        projection_mat = self.compute_projection_matrix(laplacian)
        for layer in self.layers:
            x = layer(x, laplacian_up, laplacian_down, projection_mat)
        return x