"""Simplicial Complex Convolutional Neural Network Layer."""
import torch
from torch.nn.parameter import Parameter
class SCCNNLayer(torch.nn.Module):
    r"""Layer of a Simplicial Complex Convolutional Neural Network (SCCNN).

    For each rank (nodes, edges, triangles) the layer collects the features of
    that rank together with the features projected from the neighbouring
    ranks through the incidence matrices. It diffuses each of them with
    successive powers of the corresponding Hodge Laplacian, stacks all the
    resulting terms along an "order" axis, and contracts them with a
    learnable weight tensor.

    Parameters
    ----------
    in_channels : tuple of int
        Dimensions of input features on nodes, edges, and triangles.
        Features of different ranks are concatenated along the order axis in
        ``forward``, so the three entries must all be equal.
    out_channels : tuple of int
        Dimensions of output features on nodes, edges, and triangles.
    conv_order : int
        Convolution order K of the simplicial filters, i.e. the number of
        operator powers per term; must be positive. To avoid too many
        parameters, the same order is used for every operator.
    sc_order : int
        Order of the simplicial complex; must be at least 2. When it is larger
        than 2, the triangle Laplacian is supplied as separate lower and upper
        parts. Otherwise, a single triangle Laplacian is used.
    aggr_norm : bool, default = False
        If True, the result of every operator application is divided row-wise
        by the operator's row sums (see ``aggr_norm_func``).
    update_func : {None, "sigmoid", "relu"}, default = None
        Activation applied to the outputs; ``None`` returns them unchanged.
    initialization : {"xavier_normal", "xavier_uniform"}, default = "xavier_normal"
        Weight initialization method.

    Examples
    --------
    Pseudocode of the SCCNN layer on an SC of order two, with
    ``conv_order = 2``::

        input X_0: [n_nodes, in_channels]
        input X_1: [n_edges, in_channels]
        input X_2: [n_faces, in_channels]
        graph Laplacian L_0: [n_nodes, n_nodes]
        1-Lap_down L_1_down: [n_edges, n_edges]
        1-Lap_up L_1_up: [n_edges, n_edges]
        2-Lap L_2: [n_faces, n_faces]
        1-incidence B_1: [n_nodes, n_edges]
        2-incidence B_2: [n_edges, n_faces]

        Y_0 = einsum(stack(
                X_0, L_0@X_0, L_0@L_0@X_0 ||
                B_1@X_1, L_0@B_1@X_1, L_0@L_0@B_1@X_1
              ), weight_0)
        Y_1 = einsum(stack(
                B_1.T@X_0, L_1_down@B_1.T@X_0, L_1_down@L_1_down@B_1.T@X_0 ||
                X_1, L_1_down@X_1, L_1_down@L_1_down@X_1,
                     L_1_up@X_1, L_1_up@L_1_up@X_1 ||
                B_2@X_2, L_1_up@B_2@X_2, L_1_up@L_1_up@B_2@X_2
              ), weight_1)
        Y_2 = einsum(stack(
                X_2, L_2@X_2, L_2@L_2@X_2 ||
                B_2.T@X_1, L_2@B_2.T@X_1, L_2@L_2@B_2.T@X_1
              ), weight_2)

    where

    - ``stack`` stacks the terms along a new last (order) axis and
      ``einsum`` is ``einsum("nik,iok->no", terms, weight)``;
    - weight_0, weight_1, weight_2 are the trainable parameters:

      - weight_0: [in_channels, out_channels, 2 * (1 + conv_order)]
      - weight_1: [in_channels, out_channels, 3 + 4 * conv_order]
      - weight_2: [in_channels, out_channels, 2 * (1 + conv_order)]
        (``2 + 3 * conv_order`` when ``sc_order > 2``, where L_2 is split
        into lower and upper parts that each contribute ``conv_order``
        powers of X_2);

    - powers such as ``L@L@X`` are computed as repeated products
      ``L@(L@X)``, so no matrix-matrix product is ever formed.

    References
    ----------
    .. [1] Papillon, Sanborn, Hajij, Miolane.
        Equations of topological neural networks (2023).
        https://github.com/awesome-tnns/awesome-tnns/
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        conv_order,
        sc_order,
        aggr_norm: bool = False,
        update_func=None,
        initialization: str = "xavier_normal",
    ) -> None:
        super().__init__()
        in_channels_0, in_channels_1, in_channels_2 = in_channels
        out_channels_0, out_channels_1, out_channels_2 = out_channels

        self.in_channels_0 = in_channels_0
        self.in_channels_1 = in_channels_1
        self.in_channels_2 = in_channels_2
        self.out_channels_0 = out_channels_0
        self.out_channels_1 = out_channels_1
        self.out_channels_2 = out_channels_2

        self.conv_order = conv_order
        self.sc_order = sc_order
        self.aggr_norm = aggr_norm
        self.update_func = update_func
        self.initialization = initialization

        # Kept as asserts so the exception type callers may rely on is unchanged.
        assert initialization in ["xavier_uniform", "xavier_normal"]
        assert self.conv_order > 0
        # Without this check, weight_2 would never be created and
        # reset_parameters would fail with an opaque AttributeError.
        if sc_order < 2:
            raise ValueError(f"sc_order must be at least 2, got {sc_order}.")

        # The last axis of each weight has one slot per feature term stacked in
        # ``forward``: one identity term plus ``conv_order`` powers per operator.
        n_terms_0 = 2 * (1 + conv_order)  # x_0 branch + B_1 x_1 branch
        n_terms_1 = (
            (1 + conv_order)  # B_1^T x_0 branch (L_1_down powers)
            + (1 + 2 * conv_order)  # x_1 branch (L_1_down and L_1_up powers)
            + (1 + conv_order)  # B_2 x_2 branch (L_1_up powers)
        )
        if sc_order > 2:
            # The triangle Laplacian has both lower and upper parts.
            n_terms_2 = (1 + 2 * conv_order) + (1 + conv_order)
        else:
            # The triangle Laplacian has only the lower part.
            n_terms_2 = 2 * (1 + conv_order)

        self.weight_0 = Parameter(
            torch.empty(self.in_channels_0, self.out_channels_0, n_terms_0)
        )
        self.weight_1 = Parameter(
            torch.empty(self.in_channels_1, self.out_channels_1, n_terms_1)
        )
        self.weight_2 = Parameter(
            torch.empty(self.in_channels_2, self.out_channels_2, n_terms_2)
        )

        self.reset_parameters()

    def reset_parameters(self, gain: float = 1.414):
        r"""Reset learnable parameters.

        Parameters
        ----------
        gain : float
            Gain for the Xavier weight initialization.

        Raises
        ------
        RuntimeError
            If ``self.initialization`` is not a supported method.
        """
        if self.initialization == "xavier_uniform":
            init_fn = torch.nn.init.xavier_uniform_
        elif self.initialization == "xavier_normal":
            init_fn = torch.nn.init.xavier_normal_
        else:
            raise RuntimeError(
                "Initialization method not recognized. "
                "Should be either xavier_uniform or xavier_normal."
            )
        for weight in (self.weight_0, self.weight_1, self.weight_2):
            init_fn(weight, gain=gain)

    def aggr_norm_func(self, conv_operator, x):
        r"""Divide each row of ``x`` by the matching row sum of ``conv_operator``.

        Rows whose operator row sum is zero, and any non-finite entries of the
        result, are set to zero.

        Parameters
        ----------
        conv_operator : torch.Tensor, dense or sparse, shape = (n, m)
            Operator whose row sums give the neighborhood sizes.
        x : torch.Tensor, shape = (n, num_channels)
            Tensor to normalize.

        Returns
        -------
        torch.Tensor, shape = (n, num_channels)
            Normalized tensor.
        """
        if conv_operator.layout == torch.strided:
            neighborhood_size = conv_operator.sum(dim=1)
        elif conv_operator.layout == torch.sparse_coo:
            # Sum the stored entries directly instead of densifying an n x m matrix.
            neighborhood_size = torch.sparse.sum(conv_operator, dim=1).to_dense()
        else:
            neighborhood_size = torch.sum(conv_operator.to_dense(), dim=1)
        # A zero row sum gives an infinite reciprocal; treat it as "no neighbors".
        neighborhood_size_inv = torch.nan_to_num(
            1 / neighborhood_size, nan=0.0, posinf=0.0, neginf=0.0
        )
        return torch.nan_to_num(
            neighborhood_size_inv.unsqueeze(1) * x, nan=0.0, posinf=0.0, neginf=0.0
        )

    def update(self, x):
        """Apply the configured activation to output features.

        Parameters
        ----------
        x : torch.Tensor, shape = (n_target_cells, out_channels)
            Feature tensor.

        Returns
        -------
        torch.Tensor, shape = (n_target_cells, out_channels)
            ``sigmoid(x)`` or ``relu(x)`` depending on ``update_func``; ``x``
            itself when ``update_func`` is None.

        Raises
        ------
        ValueError
            If ``update_func`` is not None, "sigmoid" or "relu".
        """
        if self.update_func is None:
            return x
        if self.update_func == "sigmoid":
            return torch.sigmoid(x)
        if self.update_func == "relu":
            return torch.nn.functional.relu(x)
        raise ValueError(
            f"Unsupported update_func {self.update_func!r}; "
            "expected None, 'sigmoid' or 'relu'."
        )

    def chebyshev_conv(self, conv_operator, conv_order, x):
        r"""Apply ``conv_operator`` to ``x`` repeatedly and stack every result.

        Despite the name, this computes plain operator powers
        ``conv_operator^k @ x`` for ``k = 1..conv_order``, not Chebyshev
        polynomials. Each power is obtained from the previous one by a single
        matrix product, so ``conv_operator @ conv_operator`` is never formed.
        When ``self.aggr_norm`` is True, every intermediate result is passed
        through ``aggr_norm_func`` before the next product.

        Parameters
        ----------
        conv_operator : torch.Tensor, dense or sparse, shape = (n_simplices, n_simplices)
            Convolution operator, e.g. an adjacency matrix or a Hodge Laplacian.
        conv_order : int
            Number of powers to compute.
        x : torch.Tensor, shape = (n_simplices, num_channels)
            Feature tensor.

        Returns
        -------
        torch.Tensor, shape = (n_simplices, num_channels, conv_order)
            ``out[:, :, k]`` is the ``(k + 1)``-th power of ``conv_operator``
            applied to ``x``. It has the dtype and device of the products.
        """
        powers = []
        current = x
        for _ in range(conv_order):
            current = torch.mm(conv_operator, current)
            if self.aggr_norm:
                current = self.aggr_norm_func(conv_operator, current)
            powers.append(current)
        if not powers:
            return x.new_empty((x.shape[0], x.shape[1], 0))
        # Stacking out-of-place keeps dtype/device and avoids in-place writes
        # into a preallocated buffer that autograd would have to track.
        return torch.stack(powers, dim=2)

    def _identity_and_powers(self, conv_operator, x):
        """Stack ``x`` followed by its ``conv_order`` diffused powers along a new last axis."""
        return torch.cat(
            (x.unsqueeze(2), self.chebyshev_conv(conv_operator, self.conv_order, x)),
            dim=2,
        )

    def forward(self, x_all, laplacian_all, incidence_all):
        r"""Forward computation (see [1]_).

        With :math:`K` = ``conv_order``, :math:`[\,\cdot\,]` stacking terms
        along a new order axis and
        :math:`(Z \star W)_{n,o} = \sum_{i,k} Z_{n,i,k} W_{i,o,k}`:

        .. math::
            \begin{align*}
            y_0 &= \big[\, x_0, L_0 x_0, \dots, L_0^K x_0 \,\big\|\,
                   B_1 x_1, L_0 B_1 x_1, \dots, L_0^K B_1 x_1 \,\big] \star W_0\\
            y_1 &= \big[\, B_1^T x_0, \dots, L_{1,\downarrow}^K B_1^T x_0 \,\big\|\,
                   x_1, L_{1,\downarrow} x_1, \dots, L_{1,\downarrow}^K x_1,
                   L_{1,\uparrow} x_1, \dots, L_{1,\uparrow}^K x_1 \,\big\|\,
                   B_2 x_2, \dots, L_{1,\uparrow}^K B_2 x_2 \,\big] \star W_1\\
            y_2 &= \big[\, x_2, L_2 x_2, \dots, L_2^K x_2 \,\big\|\,
                   B_2^T x_1, L_2 B_2^T x_1, \dots, L_2^K B_2^T x_1 \,\big] \star W_2
            \end{align*}

        When ``sc_order > 2``, the :math:`L_2^k x_2` terms are replaced by
        :math:`L_{2,\downarrow}^k x_2` followed by :math:`L_{2,\uparrow}^k x_2`.
        The :math:`B_2^T x_1` branch then uses :math:`L_{2,\downarrow}`. If
        ``update_func`` is set, it is applied elementwise to each output.

        Parameters
        ----------
        x_all : tuple of tensors, shape = (x_0, x_1, x_2)
            Tuple of input feature tensors:

            - x_0: torch.Tensor, shape = (n_nodes, in_channels_0),
            - x_1: torch.Tensor, shape = (n_edges, in_channels_1),
            - x_2: torch.Tensor, shape = (n_triangles, in_channels_2).
        laplacian_all : tuple of tensors
            For ``sc_order == 2``:
            ``(laplacian_0, laplacian_down_1, laplacian_up_1, laplacian_2)``.
            For ``sc_order > 2``:
            ``(laplacian_0, laplacian_down_1, laplacian_up_1, laplacian_down_2, laplacian_up_2)``.

            - laplacian_0: graph Laplacian, shape = (n_nodes, n_nodes),
            - laplacian_down_1 / laplacian_up_1: lower / upper parts of the
              1-Hodge Laplacian, shape = (n_edges, n_edges),
            - laplacian_2 (or laplacian_down_2 / laplacian_up_2): the 2-Hodge
              Laplacian (or its lower / upper parts),
              shape = (n_triangles, n_triangles).
        incidence_all : tuple of tensors, shape = (b1, b2)
            Tuple of incidence tensors:

            - b1: torch.sparse, shape = (n_nodes, n_edges), node-to-edge incidence matrix,
            - b2: torch.sparse, shape = (n_edges, n_triangles), edge-to-face incidence matrix.

        Returns
        -------
        y_0 : torch.Tensor, shape = (n_nodes, out_channels_0)
            Output features on nodes.
        y_1 : torch.Tensor, shape = (n_edges, out_channels_1)
            Output features on edges.
        y_2 : torch.Tensor, shape = (n_triangles, out_channels_2)
            Output features on triangles.
        """
        x_0, x_1, x_2 = x_all
        b1, b2 = incidence_all
        if self.sc_order == 2:
            laplacian_0, laplacian_down_1, laplacian_up_1, laplacian_2 = laplacian_all
        elif self.sc_order > 2:
            (
                laplacian_0,
                laplacian_down_1,
                laplacian_up_1,
                laplacian_down_2,
                laplacian_up_2,
            ) = laplacian_all
        else:
            raise ValueError(f"sc_order must be at least 2, got {self.sc_order}.")

        # Convolution in the node space: node features and edge features
        # projected onto nodes, each diffused by powers of L_0.
        x_0_all = torch.cat(
            (
                self._identity_and_powers(laplacian_0, x_0),
                self._identity_and_powers(laplacian_0, torch.mm(b1, x_1)),
            ),
            dim=2,
        )

        # Convolution in the edge space: node features lifted to edges (lower
        # Laplacian), edge features (lower and upper Laplacians), and triangle
        # features projected onto edges (upper Laplacian).
        x_0_to_1 = self._identity_and_powers(laplacian_down_1, torch.mm(b1.T, x_0))
        x_1_to_1 = torch.cat(
            (
                x_1.unsqueeze(2),
                self.chebyshev_conv(laplacian_down_1, self.conv_order, x_1),
                self.chebyshev_conv(laplacian_up_1, self.conv_order, x_1),
            ),
            dim=2,
        )
        x_2_to_1 = self._identity_and_powers(laplacian_up_1, torch.mm(b2, x_2))
        x_1_all = torch.cat((x_0_to_1, x_1_to_1, x_2_to_1), dim=2)

        # Convolution in the face (triangle) space; its exact form depends on
        # whether the triangle Laplacian is split into lower and upper parts.
        if self.sc_order == 2:
            x_2_to_2 = self._identity_and_powers(laplacian_2, x_2)
            lower_laplacian_2 = laplacian_2
        else:
            x_2_to_2 = torch.cat(
                (
                    x_2.unsqueeze(2),
                    self.chebyshev_conv(laplacian_down_2, self.conv_order, x_2),
                    self.chebyshev_conv(laplacian_up_2, self.conv_order, x_2),
                ),
                dim=2,
            )
            lower_laplacian_2 = laplacian_down_2
        x_1_to_2 = self._identity_and_powers(lower_laplacian_2, torch.mm(b2.T, x_1))
        x_2_all = torch.cat((x_2_to_2, x_1_to_2), dim=2)

        # Contract over input channels (i) and stacked terms (k).
        y_0 = torch.einsum("nik,iok->no", x_0_all, self.weight_0)
        y_1 = torch.einsum("nik,iok->no", x_1_all, self.weight_1)
        y_2 = torch.einsum("nik,iok->no", x_2_all, self.weight_2)

        if self.update_func is None:
            return y_0, y_1, y_2
        return self.update(y_0), self.update(y_1), self.update(y_2)