Source code for topomodelx.nn.hypergraph.allset
"""Allset class."""
import torch
from topomodelx.nn.hypergraph.allset_layer import AllSetLayer
[docs]
class AllSet(torch.nn.Module):
    """AllSet Neural Network Module.

    A module that combines multiple AllSet layers [1]_ to form a neural network.

    Parameters
    ----------
    in_channels : int
        Dimension of the input features.
    hidden_channels : int
        Dimension of the hidden features.
    n_layers : int, default = 2
        Number of AllSet layers in the network. Must be at least 1.
    layer_dropout : float, default = 0.2
        Dropout probability for the AllSet layer.
    mlp_num_layers : int, default = 2
        Number of layers in the MLP.
    mlp_activation : torch.nn.Module, default = None
        Activation function in the MLP, forwarded unchanged to every
        ``AllSetLayer``.
    mlp_dropout : float, default = 0.0
        Dropout probability for the MLP.
    mlp_norm : optional, default = None
        Normalization setting for the MLP, forwarded unchanged to every
        ``AllSetLayer``. NOTE(review): earlier docs described this as a
        ``bool`` defaulting to ``False``, but the signature default is
        ``None``; confirm the accepted type against ``AllSetLayer``.
    **kwargs : optional
        Additional arguments for the inner layers.

    Raises
    ------
    ValueError
        If ``n_layers`` is smaller than 1.

    References
    ----------
    .. [1] Chien, Pan, Peng and Milenkovic.
        You are AllSet: a multiset function framework for hypergraph neural networks.
        ICLR 2022.
        https://arxiv.org/abs/2106.13264
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        n_layers: int = 2,
        layer_dropout: float = 0.2,
        mlp_num_layers: int = 2,
        mlp_activation=None,
        mlp_dropout: float = 0.0,
        mlp_norm=None,
        **kwargs,
    ):
        super().__init__()
        if n_layers < 1:
            # forward() returns the hyperedge features of the last layer, so an
            # empty layer stack would fail there with an UnboundLocalError.
            raise ValueError(f"n_layers must be at least 1, got {n_layers}.")
        self.layers = torch.nn.ModuleList(
            AllSetLayer(
                # Only the first layer sees the raw input dimension; every
                # later layer consumes the previous layer's hidden features.
                in_channels=in_channels if i == 0 else hidden_channels,
                hidden_channels=hidden_channels,
                dropout=layer_dropout,
                mlp_num_layers=mlp_num_layers,
                mlp_activation=mlp_activation,
                mlp_dropout=mlp_dropout,
                mlp_norm=mlp_norm,
                **kwargs,
            )
            for i in range(n_layers)
        )

    def forward(self, x_0, incidence_1):
        """Forward computation.

        The layers are applied in sequence. Each layer receives the current
        node features and ``incidence_1``, and returns new node and hyperedge
        features.

        Parameters
        ----------
        x_0 : torch.Tensor
            Input node features.
        incidence_1 : torch.Tensor
            Hypergraph connectivity, passed unchanged to every layer. Earlier
            docs described it as an edge list of size (2, |E|).
            NOTE(review): confirm the expected format against ``AllSetLayer``.

        Returns
        -------
        x_0 : torch.Tensor
            Output node features.
        x_1 : torch.Tensor
            Output hyperedge features produced by the last layer.
        """
        for layer in self.layers:
            # Only x_0 is fed into the next layer; x_1 is overwritten on every
            # iteration, so only the last layer's hyperedge features survive.
            x_0, x_1 = layer(x_0, incidence_1)
        return x_0, x_1