# Source code for topomodelx.nn.hypergraph.hypersage
"""HyperSAGE Layer."""
import torch
from topomodelx.nn.hypergraph.hypersage_layer import HyperSAGELayer
# [docs] (Sphinx source-view link marker left over from the HTML page)
class HyperSAGE(torch.nn.Module):
    """Neural network implementation of HyperSAGE [1]_ as a stack of HyperSAGE layers.

    The module only chains message passing layers; it applies no linear
    read-out and no pooling, so ``forward`` returns per-node features.

    Parameters
    ----------
    in_channels : int
        Dimension of the input features.
    hidden_channels : int
        Dimension of the hidden features. Every layer is built with
        ``out_channels=hidden_channels``.
    n_layers : int, default = 2
        Amount of message passing layers.
    alpha : int, default = -1
        Max number of nodes in a neighborhood to consider. If -1 it considers all the nodes.
    **kwargs : optional
        Additional arguments for the inner layers.

    References
    ----------
    .. [1] Arya, Gupta, Rudinac and Worring.
        HyperSAGE: Generalizing inductive representation learning on hypergraphs (2020).
        https://arxiv.org/abs/2010.04558
    """

    def __init__(self, in_channels, hidden_channels, n_layers=2, alpha=-1, **kwargs):
        super().__init__()

        # Only the first layer consumes the raw input width; every later
        # layer maps hidden_channels -> hidden_channels.
        self.layers = torch.nn.ModuleList(
            HyperSAGELayer(
                in_channels=in_channels if i == 0 else hidden_channels,
                out_channels=hidden_channels,
                alpha=alpha,
                **kwargs,
            )
            for i in range(n_layers)
        )

    def forward(self, x_0, incidence_1):
        """Forward computation through the stacked message passing layers.

        Parameters
        ----------
        x_0 : torch.Tensor, shape = (n_nodes, features_nodes)
            Node features.
        incidence_1 : torch.Tensor, shape = (n_nodes, n_edges)
            Incidence (boundary) matrix of rank 1 relating nodes to
            hyperedges. It is passed unchanged to every layer.

        Returns
        -------
        torch.Tensor
            Output of the last layer. Expected to have shape
            (n_nodes, hidden_channels) since each layer is built with
            ``out_channels=hidden_channels``; if the model has no layers,
            ``x_0`` is returned unchanged.
        """
        for layer in self.layers:
            # Call the module itself rather than ``layer.forward`` so that
            # hooks registered on the layer are executed.
            x_0 = layer(x_0, incidence_1)
        return x_0