Source code for topomodelx.nn.hypergraph.unigcnii_layer
"""UniGCNII layer implementation."""
import torch
from topomodelx.base.conv import Conv
class UniGCNIILayer(torch.nn.Module):
r"""
Implementation of the UniGCNII layer [1]_.
Parameters
----------
in_channels : int
Dimension of the input features.
hidden_channels : int
Dimension of the hidden features.
alpha : float
The alpha parameter determining the importance of the self-loop (\theta_2).
beta : float
The beta parameter determining the importance of the learned matrix (\theta_1).
use_norm : bool, default=False
Whether to apply row normalization after the layer.
**kwargs : optional
Additional arguments for the layer modules.
References
----------
.. [1] Huang and Yang.
UniGNN: a unified framework for graph and hypergraph neural networks.
IJCAI 2021.
https://arxiv.org/pdf/2105.00956.pdf
"""

    def __init__(
        self,
        in_channels,
        hidden_channels,
        alpha: float,
        beta: float,
        use_norm=False,
        **kwargs,
    ) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.linear = torch.nn.Linear(in_channels, hidden_channels, bias=False)
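        # The Conv below is used without a linear transform, i.e. it only
        # aggregates features over the neighborhood matrix passed to it.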
        self.conv = Conv(
            in_channels=in_channels,
            out_channels=in_channels,
            with_linear_transform=False,
        )
        self.use_norm = use_norm

    def reset_parameters(self) -> None:
        """Reset the parameters of the layer."""
        self.linear.reset_parameters()

    def forward(self, x_0, incidence_1, x_skip=None):
        r"""Forward pass of the UniGCNII layer.

        The forward pass consists of:

        - two messages, and
        - a skip connection with a learned update function.

        First, every hyperedge sums up the features of its constituent nodes:

        .. math::
            \begin{align*}
            & 🟥 \quad m_{y \rightarrow z}^{(0 \rightarrow 1)} = (B^T_1)_{zy} \cdot h^{t,(0)}_y \\
            & 🟧 \quad m_z^{(0\rightarrow1)} = \sum_{y \in \mathcal{B}(z)} m_{y \rightarrow z}^{(0 \rightarrow 1)}
            \end{align*}

        Second, the hyperedge features are passed back to the nodes and normalized
        by the node and edge degrees:

        .. math::
            \begin{align*}
            & 🟥 \quad m_{z \rightarrow x}^{(1 \rightarrow 0)} = (B_1)_{xz} \cdot m_z^{(0 \rightarrow 1)} \\
            & 🟧 \quad m_{x}^{(1\rightarrow0)} = \frac{1}{\sqrt{d_x}}\sum_{z \in \mathcal{C}(x)} \frac{1}{\sqrt{d_z}} m_{z \rightarrow x}^{(1\rightarrow0)} \\
            \end{align*}

        Third, the computed message is combined with skip connections and a linear
        transformation, using the hyperparameters alpha and beta:

        .. math::
            \begin{align*}
            & 🟩 \quad m_x^{(0)} = m_x^{(1 \rightarrow 0)} \\
            & 🟦 \quad h_x^{t+1,(0)} = ((1-\beta)I + \beta W)((1-\alpha)m_x^{(0)} + \alpha \cdot h_x^{t,(0)}) \\
            \end{align*}

        Parameters
        ----------
        x_0 : torch.Tensor, shape = (num_nodes, in_channels)
            Input features of the nodes of the hypergraph.
        incidence_1 : torch.Tensor, shape = (num_nodes, num_edges)
            Incidence matrix of the hypergraph.
            It is expected that the incidence matrix contains self-loops for all nodes.
        x_skip : torch.Tensor, shape = (num_nodes, in_channels), optional
            Original node features of the hypergraph used for the skip connections.
            If not provided, the input to the layer is used as a skip connection.

        Returns
        -------
        x_0 : torch.Tensor
            Output node features.
        x_1 : torch.Tensor
            Output hyperedge features.
        """
        x_skip = x_0 if x_skip is None else x_skip
        incidence_1_transpose = incidence_1.transpose(0, 1)

        # First message without any learning or parameters.
        x_1 = self.conv(x_0, incidence_1_transpose)
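        # x_1[z] now aggregates the features of the nodes contained in hyperedge z.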

        # Compute node and edge degrees for normalization.
        node_degree = torch.sum(incidence_1.to_dense(), dim=1)
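        # node_degree[x] counts the hyperedges containing node x (d_x above).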
        # Check that the node degrees are positive.
        assert torch.all(
            node_degree > 0
        ), "Node degrees should be positive (at least self-loops should be included)."
        # Average node degree for each edge.
        edge_degree = torch.sum(torch.diag(node_degree) @ incidence_1, dim=0)
        assert torch.all(
            edge_degree > 0
        ), "Edge degrees should be positive (every edge needs to connect at least one node)."
        edge_degree = edge_degree / torch.sum(incidence_1.to_dense(), dim=0)
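        # edge_degree[z] is now the average degree of the nodes in hyperedge z (d_z above).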

        # Second message normalized with node and edge degrees (using broadcasting).
        x_0 = (1 / torch.sqrt(node_degree).unsqueeze(-1)) * self.conv(
            x_1, incidence_1 @ torch.diag(1 / torch.sqrt(edge_degree))
        )

        # Introduce skip connections with hyperparameters alpha and beta.
        x_combined = ((1 - self.alpha) * x_0) + (self.alpha * x_skip)
        x_0 = ((1 - self.beta) * x_combined) + self.beta * self.linear(x_combined)
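
        # Optional row normalization: rescale each node feature vector to unit norm
        # (rows with zero norm are left unchanged).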
        if self.use_norm:
            rownorm = x_0.detach().norm(dim=1, keepdim=True)
            scale = rownorm.pow(-1)
            scale[torch.isinf(scale)] = 0.0
            x_0 = x_0 * scale

        return x_0, x_1
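

# Usage sketch (illustrative only): the toy hypergraph, shapes, and hyperparameter
# values below are assumptions made for this example, not part of the module.
# Three nodes share one hyperedge, and one self-loop hyperedge is added per node,
# as expected by ``forward``.
#
# >>> import torch
# >>> x_0 = torch.rand(3, 8)
# >>> incidence_1 = torch.cat([torch.ones(3, 1), torch.eye(3)], dim=1)
# >>> layer = UniGCNIILayer(in_channels=8, hidden_channels=8, alpha=0.1, beta=0.5)
# >>> x_0_out, x_1_out = layer(x_0, incidence_1)
# >>> x_0_out.shape, x_1_out.shape
# (torch.Size([3, 8]), torch.Size([4, 8]))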