# Source code for topomodelx.nn.hypergraph.hnhn

"""HNHN class."""

import torch

from topomodelx.nn.hypergraph.hnhn_layer import HNHNLayer


class HNHN(torch.nn.Module):
    """Hypergraph Networks with Hyperedge Neurons [1]_.

    Implementation for multiclass node classification. The model is a stack
    of ``HNHNLayer`` modules; dropout is applied to the node features after
    every layer.

    Parameters
    ----------
    in_channels : int
        Dimension of the input features.
    hidden_channels : int
        Dimension of the hidden features.
    incidence_1 : torch.sparse, shape = (n_nodes, n_edges)
        Incidence matrix mapping edges to nodes (B_1). It is handed to every
        inner layer at construction time.
    n_layers : int, default = 2
        Number of HNHN message passing layers. Must be at least 1.
    layer_drop : float, default = 0.2
        Dropout rate for the hidden features.
    **kwargs : optional
        Additional arguments for the inner layers.

    Raises
    ------
    ValueError
        If ``n_layers`` is smaller than 1.

    References
    ----------
    .. [1] Dong, Sawin, Bengio.
        HNHN: hypergraph networks with hyperedge neurons.
        Graph Representation Learning and Beyond Workshop at ICML 2020.
        https://grlplus.github.io/papers/40.pdf
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        incidence_1,
        n_layers=2,
        layer_drop=0.2,
        **kwargs,
    ):
        super().__init__()

        # With no layers, ``forward`` would never assign the hyperedge
        # features and fail with an opaque UnboundLocalError; reject the
        # configuration up front with a clear message instead.
        if n_layers < 1:
            raise ValueError(
                f"n_layers must be a positive integer, got {n_layers!r}."
            )

        # Only the first layer consumes the raw input width; every later
        # layer maps hidden features to hidden features.
        self.layers = torch.nn.ModuleList(
            [
                HNHNLayer(
                    in_channels=in_channels if i == 0 else hidden_channels,
                    hidden_channels=hidden_channels,
                    incidence_1=incidence_1,
                    **kwargs,
                )
                for i in range(n_layers)
            ]
        )
        self.layer_drop = torch.nn.Dropout(layer_drop)

    def forward(self, x_0, incidence_1=None):
        """Forward computation.

        Parameters
        ----------
        x_0 : torch.Tensor, shape = (n_nodes, channels_node)
            Hypernode features.
        incidence_1 : torch.Tensor, shape = (n_nodes, n_edges), optional
            Boundary matrix of rank 1. It is forwarded unchanged to every
            layer; how ``None`` is handled is decided by ``HNHNLayer``.

        Returns
        -------
        x_0 : torch.Tensor
            Output node features (after dropout).
        x_1 : torch.Tensor
            Output hyperedge features produced by the last layer.
        """
        for layer in self.layers:
            x_0, x_1 = layer(x_0, incidence_1)
            # Dropout acts on node features only and is applied after every
            # layer, including the last one.
            x_0 = self.layer_drop(x_0)

        return x_0, x_1