"""AllSet Layer Module."""
import torch.nn.functional as F
from torch import nn
from topomodelx.base.conv import Conv
class AllSetLayer(nn.Module):
"""
AllSet Layer Module [1]_.
    A module for the AllSet layer in a bipartite graph.

Parameters
----------
in_channels : int
Dimension of the input features.
hidden_channels : int
Dimension of the hidden features.
dropout : float, default=0.2
Dropout probability.
mlp_num_layers : int, default=2
Number of layers in the MLP.
    mlp_activation : callable or None, default=nn.ReLU
Activation function in the MLP.
    mlp_dropout : float, default=0.0
Dropout probability in the MLP.
    mlp_norm : callable or None, optional
Type of layer normalization in the MLP.
**kwargs : optional
Additional arguments for the layer modules.

    References
----------
.. [1] Chien, Pan, Peng and Milenkovic.
You are AllSet: a multiset function framework for hypergraph neural networks.
ICLR 2022.
https://arxiv.org/abs/2106.13264
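
    Examples
    --------
    A minimal usage sketch; the node/hyperedge counts and the incidence
    matrix below are illustrative, not from a real dataset:

    >>> import torch
    >>> layer = AllSetLayer(in_channels=8, hidden_channels=16)
    >>> x_0 = torch.randn(4, 8)
    >>> incidence_1 = torch.tensor(
    ...     [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]]
    ... ).to_sparse()
    >>> x_0_out, x_1_out = layer(x_0, incidence_1)
    >>> x_0_out.shape, x_1_out.shape
    (torch.Size([4, 16]), torch.Size([3, 16]))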
"""
def __init__(
self,
in_channels,
hidden_channels,
dropout: float = 0.2,
mlp_num_layers: int = 2,
mlp_activation=nn.ReLU,
mlp_dropout: float = 0.0,
mlp_norm=None,
**kwargs,
) -> None:
super().__init__()
if mlp_num_layers <= 0:
raise ValueError(f"mlp_num_layers ({mlp_num_layers}) must be positive")
self.dropout = dropout
self.vertex2edge = AllSetBlock(
in_channels=in_channels,
hidden_channels=hidden_channels,
dropout=dropout,
mlp_num_layers=mlp_num_layers,
mlp_activation=mlp_activation,
mlp_dropout=mlp_dropout,
mlp_norm=mlp_norm,
**kwargs,
)
self.edge2vertex = AllSetBlock(
in_channels=hidden_channels,
hidden_channels=hidden_channels,
dropout=dropout,
mlp_num_layers=mlp_num_layers,
mlp_activation=mlp_activation,
mlp_dropout=mlp_dropout,
mlp_norm=mlp_norm,
**kwargs,
)
def reset_parameters(self) -> None:
"""Reset learnable parameters."""
self.vertex2edge.reset_parameters()
self.edge2vertex.reset_parameters()
def forward(self, x_0, incidence_1):
r"""
Forward computation.
        Vertex to edge:

        .. math::
            \begin{align*}
            &🟧 \quad m_{\rightarrow z}^{(\rightarrow 1)}
                = AGG_{y \in \mathcal{B}(z)} (h_y^{t, (0)}, h_z^{t,(1)}) \\
            &🟦 \quad h_z^{t+1,(1)}
                = \sigma(m_{\rightarrow z}^{(\rightarrow 1)})
            \end{align*}

        Edge to vertex:

        .. math::
            \begin{align*}
            &🟧 \quad m_{\rightarrow x}^{(\rightarrow 0)}
                = AGG_{z \in \mathcal{C}(x)} (h_z^{t+1,(1)}, h_x^{t,(0)}) \\
            &🟦 \quad h_x^{t+1,(0)}
                = \sigma(m_{\rightarrow x}^{(\rightarrow 0)})
            \end{align*}

        Parameters
----------
x_0 : torch.Tensor, shape = (n_nodes, channels)
Node input features.
incidence_1 : torch.sparse, shape = (n_nodes, n_hyperedges)
Incidence matrix :math:`B_1` mapping hyperedges to nodes.

        Returns
        -------
        x_0 : torch.Tensor, shape = (n_nodes, hidden_channels)
            Output node features.
        x_1 : torch.Tensor, shape = (n_hyperedges, hidden_channels)
            Output hyperedge features.
"""
        if x_0.shape[-2] != incidence_1.shape[-2]:
            raise ValueError(
                f"Shape of incidence matrix ({incidence_1.shape}) does not match the number of nodes ({x_0.shape[-2]})."
            )
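        # Vertex-to-edge message passing: B_1^T aggregates node features into hyperedge features.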
x_1 = self.vertex2edge(x_0, incidence_1.transpose(1, 0))
x_1 = F.dropout(x_1, p=self.dropout, training=self.training)
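        # Edge-to-vertex message passing: B_1 aggregates hyperedge features back onto the nodes.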
x_0 = self.edge2vertex(x_1, incidence_1)
x_0 = F.dropout(x_0, p=self.dropout, training=self.training)
return x_0, x_1
class MLP(nn.Sequential):
"""MLP Module.
    A module for a multi-layer perceptron (MLP).

Parameters
----------
in_channels : int
Dimension of the input features.
hidden_channels : list of int
List of dimensions of the hidden features.
norm_layer : callable or None, optional
Type of layer normalization.
activation_layer : callable or None, optional
Type of activation function.
dropout : float, default=0.0
Dropout probability.
    inplace : bool or None, default=None
        Whether to do the operation in-place; ``None`` uses each module's default.
bias : bool, default=False
Whether to add bias.
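
    Examples
    --------
    A minimal sketch; the sizes are illustrative. ``hidden_channels`` lists
    the output dimension of every linear layer:

    >>> import torch
    >>> mlp = MLP(8, [16, 16], activation_layer=torch.nn.ReLU)
    >>> mlp(torch.randn(4, 8)).shape
    torch.Size([4, 16])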
"""
def __init__(
self,
in_channels,
hidden_channels,
norm_layer=None,
activation_layer=None,
dropout: float = 0.0,
inplace: bool | None = None,
bias: bool = False,
) -> None:
params = {} if inplace is None else {"inplace": inplace}
layers: list[nn.Module] = []
in_dim = in_channels
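        # Hidden layers: Linear -> optional norm -> activation -> dropout.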
for hidden_dim in hidden_channels[:-1]:
layers.append(nn.Linear(in_dim, hidden_dim, bias=bias))
if norm_layer is not None:
layers.append(norm_layer(hidden_dim))
            if activation_layer is not None:
                layers.append(activation_layer(**params))
layers.append(nn.Dropout(dropout, **params))
in_dim = hidden_dim
layers.append(nn.Linear(in_dim, hidden_channels[-1], bias=bias))
layers.append(nn.Dropout(dropout, **params))
super().__init__(*layers)
class AllSetBlock(nn.Module):
"""AllSet Block Module.
    A module for the AllSet block in a bipartite graph.

Parameters
----------
in_channels : int
Dimension of the input features.
hidden_channels : int
Dimension of the hidden features.
dropout : float, default=0.2
Dropout probability.
mlp_num_layers : int, default=2
Number of layers in the MLP.
mlp_activation : callable or None, optional
Activation function in the MLP.
    mlp_dropout : float, default=0.0
Dropout probability in the MLP.
mlp_norm : callable or None, optional
Type of layer normalization in the MLP.
**kwargs : optional
Additional arguments for the block modules.
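
    Examples
    --------
    A minimal vertex-to-edge sketch (4 nodes, 3 hyperedges; the incidence
    matrix is illustrative). The block expects the neighborhood matrix that
    maps sources to targets, hence the transpose:

    >>> import torch
    >>> block = AllSetBlock(in_channels=8, hidden_channels=16)
    >>> x_0 = torch.randn(4, 8)
    >>> incidence_1 = torch.tensor(
    ...     [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]]
    ... ).to_sparse()
    >>> block(x_0, incidence_1.transpose(1, 0)).shape
    torch.Size([3, 16])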
"""
encoder: MLP | nn.Identity
decoder: MLP | nn.Identity
def __init__(
self,
in_channels,
hidden_channels,
dropout: float = 0.2,
mlp_num_layers: int = 2,
mlp_activation=nn.ReLU,
mlp_dropout: float = 0.0,
mlp_norm=None,
**kwargs,
) -> None:
super().__init__()
self.dropout = dropout
if mlp_num_layers > 0:
mlp_hidden_channels = [hidden_channels] * mlp_num_layers
self.encoder = MLP(
in_channels,
mlp_hidden_channels,
norm_layer=mlp_norm,
activation_layer=mlp_activation,
dropout=mlp_dropout,
)
self.decoder = MLP(
hidden_channels,
mlp_hidden_channels,
norm_layer=mlp_norm,
activation_layer=mlp_activation,
dropout=mlp_dropout,
)
in_channels = hidden_channels
else:
self.encoder = nn.Identity()
self.decoder = nn.Identity()
self.conv = Conv(
in_channels=in_channels,
out_channels=hidden_channels,
aggr_norm=True,
update_func="relu",
att=False,
)
def reset_parameters(self) -> None:
"""Reset learnable parameters."""
        # nn.Sequential and nn.Identity expose no reset_parameters of their
        # own, so reset any submodule (e.g. nn.Linear) that defines one.
        for module in (self.encoder, self.decoder):
            for submodule in module.modules():
                if hasattr(submodule, "reset_parameters"):
                    submodule.reset_parameters()
self.conv.reset_parameters()
def forward(self, x_0, incidence_1):
"""
Forward computation.

        Parameters
        ----------
        x_0 : torch.Tensor
            Input node features.
        incidence_1 : torch.sparse
            Incidence matrix between nodes and hyperedges.

        Returns
        -------
        torch.Tensor
            Output features.
"""
x = F.relu(self.encoder(x_0))
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.conv(x, incidence_1)
return F.relu(self.decoder(x))