# Source code for topomodelx.nn.hypergraph.hmpnn
"""HMPNN class."""
import torch
from topomodelx.nn.hypergraph.hmpnn_layer import HMPNNLayer
# [docs]
class HMPNN(torch.nn.Module):
    """Neural network implementation of HMPNN [1]_.

    Node and hyperedge features are first projected from ``in_channels`` to
    ``hidden_channels`` by two separate linear layers, then passed through a
    stack of ``n_layers`` ``HMPNNLayer`` modules.

    Parameters
    ----------
    in_channels : int
        Dimension of input features (used for both nodes and hyperedges).
    hidden_channels : int
        Dimension of the hidden features. It is the output size of both input
        linear projections and the channel size given to every ``HMPNNLayer``.
        It must be a single integer, not a tuple.
    n_layers : int, default = 2
        Number of HMPNNLayer layers.
    adjacency_dropout_rate : float, default = 0.7
        Adjacency dropout rate, forwarded to each layer as
        ``adjacency_dropout``.
    regular_dropout_rate : float, default = 0.5
        Regular dropout rate applied on features, forwarded to each layer as
        ``updating_dropout``.
    **kwargs : optional
        Additional arguments for the inner layers.

    References
    ----------
    .. [1] Heydari S, Livi L.
        Message passing neural networks for hypergraphs.
        ICANN 2022.
        https://arxiv.org/abs/2203.16995
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        n_layers: int = 2,
        adjacency_dropout_rate: float = 0.7,
        regular_dropout_rate: float = 0.5,
        **kwargs,
    ) -> None:
        super().__init__()

        # Separate input projections for nodes and hyperedges; both map
        # in_channels -> hidden_channels.
        self.linear_node = torch.nn.Linear(in_channels, hidden_channels)
        self.linear_edge = torch.nn.Linear(in_channels, hidden_channels)

        # Every message-passing layer is built with the same configuration.
        self.layers = torch.nn.ModuleList(
            [
                HMPNNLayer(
                    hidden_channels,
                    adjacency_dropout=adjacency_dropout_rate,
                    updating_dropout=regular_dropout_rate,
                    **kwargs,
                )
                for _ in range(n_layers)
            ]
        )

    def forward(self, x_0: torch.Tensor, x_1: torch.Tensor, incidence_1: torch.Tensor):
        """Forward computation through layers.

        Parameters
        ----------
        x_0 : torch.Tensor, shape = (n_nodes, in_channels)
            Node features.
        x_1 : torch.Tensor, shape = (n_hyperedges, in_channels)
            Hyperedge features.
        incidence_1 : torch.sparse.Tensor, shape = (n_nodes, n_hyperedges)
            Incidence matrix (B1). It is passed unchanged to every layer.

        Returns
        -------
        x_0 : torch.Tensor
            Output node features. With ``n_layers == 0`` this is exactly the
            linear projection, of shape (n_nodes, hidden_channels); the layers
            are presumed to preserve that shape (confirm against HMPNNLayer).
        x_1 : torch.Tensor
            Output hyperedge features, with the same caveat as ``x_0``.
        """
        # Project raw features into the hidden space expected by the layers.
        x_0 = self.linear_node(x_0)
        x_1 = self.linear_edge(x_1)

        # Each layer consumes and returns updated (node, hyperedge) features.
        for layer in self.layers:
            x_0, x_1 = layer(x_0, x_1, incidence_1)

        return x_0, x_1