Source code for topomodelx.nn.hypergraph.hmpnn_layer

"""HMPNN (Hypergraph Message Passing Neural Network) Layer introduced in Heydari et Livi 2022."""
from typing import Literal

import torch
from torch import nn
from torch.nn import functional as F

from topomodelx.base.message_passing import MessagePassing
from topomodelx.utils.scatter import scatter


class _AdjacencyDropoutMixin:
    r"""Mixin adding dropout over the stored values of a sparse neighborhood matrix.

    The host class must expose a boolean ``training`` attribute; it is forwarded
    to ``F.dropout`` to switch the dropout on and off.
    """

    training: bool

    def apply_dropout(self, neighborhood, dropout_rate: float):
        r"""Apply dropout to the values of a sparse neighborhood matrix.

        Parameters
        ----------
        neighborhood : torch.sparse.Tensor
            Sparse tensor representing the adjacency matrix.
        dropout_rate : float
            Dropout rate, passed to ``F.dropout`` as ``p``.

        Returns
        -------
        torch.sparse.Tensor
            Coalesced sparse tensor with the same indices and size as the
            coalesced input, whose values were cast to float and passed through
            dropout.
        """
        coalesced = neighborhood.coalesce()

        # Dropout only touches the stored values; it is a no-op unless
        # ``self.training`` is true.
        dropped_values = F.dropout(
            coalesced.values().to(torch.float),
            p=dropout_rate,
            training=self.training,
        )

        # NOTE(review): entries zeroed by dropout appear to stay stored as
        # explicit zeros, so callers that only read ``.indices()`` would still
        # see them -- confirm this is intended.
        rebuilt = torch.sparse_coo_tensor(
            coalesced.indices(), dropped_values, coalesced.size()
        )
        return rebuilt.coalesce()


class _NodeToHyperedgeMessenger(MessagePassing, _AdjacencyDropoutMixin):
    r"""Messenger sending node messages to hyperedges.

    Parameters
    ----------
    messaging_func : callable
        Function turning node features into node messages.
    adjacency_dropout : float, default = 0.7
        Dropout rate applied to the neighborhood matrix before messaging.
    aggr_func : Literal["sum", "mean", "add"], default="sum"
        Message aggregation function.
    """

    def __init__(
        self,
        messaging_func,
        adjacency_dropout: float = 0.7,
        aggr_func: Literal["sum", "mean", "add"] = "sum",
    ) -> None:
        super().__init__(aggr_func)
        self.adjacency_dropout = adjacency_dropout
        self.messaging_func = messaging_func

    def message(self, x_source):
        r"""Compute the message emitted by every node.

        Parameters
        ----------
        x_source : torch.Tensor
            Source node features.

        Returns
        -------
        torch.Tensor
            Result of ``messaging_func`` applied to the node features.
        """
        node_messages = self.messaging_func(x_source)
        return node_messages

    def forward(self, x_source, neighborhood):
        r"""Send node messages and aggregate them.

        Parameters
        ----------
        x_source : torch.Tensor
            Source node features.
        neighborhood : torch.sparse.Tensor
            Sparse neighborhood matrix. After dropout, its first index row is
            used to select source messages and its second index row is stored
            as the aggregation target index.

        Returns
        -------
        x_message_aggregated : torch.Tensor
            Aggregated messages passed from the nodes to the hyperedge.
        x_message : torch.Tensor
            Messages passed from the nodes to the hyperedge (one per node,
            before gathering and aggregation).
        """
        dropped = self.apply_dropout(neighborhood, self.adjacency_dropout)
        node_index, hyperedge_index = dropped.indices()
        # ``aggregate`` is inherited; the target index is stored on the
        # instance before it is called.
        self.target_index_i = hyperedge_index

        x_message = self.message(x_source)
        # One message copy per stored (node, hyperedge) entry.
        gathered = x_message.index_select(-2, node_index)
        x_message_aggregated = self.aggregate(gathered)
        return x_message_aggregated, x_message


class _HyperedgeToNodeMessenger(MessagePassing, _AdjacencyDropoutMixin):
    r"""Messenger sending hyperedge messages back to nodes.

    Parameters
    ----------
    messaging_func : callable
        Function taking hyperedge features and aggregated node messages and
        returning hyperedge messages.
    adjacency_dropout : float, default = 0.7
        Dropout rate applied to the neighborhood matrix before messaging.
    aggr_func : Literal["sum", "mean", "add"], default="sum"
        Message aggregation function.
    """

    def __init__(
        self,
        messaging_func,
        adjacency_dropout: float = 0.7,
        aggr_func: Literal["sum", "mean", "add"] = "sum",
    ) -> None:
        super().__init__(aggr_func)
        self.adjacency_dropout = adjacency_dropout
        self.messaging_func = messaging_func

    def message(self, x_source, neighborhood, node_messages):
        r"""Compute the message emitted by every hyperedge.

        Parameters
        ----------
        x_source : torch.Tensor
            Source hyperedge features.
        neighborhood : torch.sparse.Tensor
            Sparse neighborhood matrix. After dropout, its first index row
            selects node messages and its second index row is the scatter
            target.
        node_messages : torch.Tensor
            Messages passed from the nodes to the hyperedge.

        Returns
        -------
        torch.Tensor
            Result of ``messaging_func`` applied to the hyperedge features and
            the node messages aggregated with ``aggr_func``.
        """
        dropped = self.apply_dropout(neighborhood, self.adjacency_dropout)
        node_index, hyperedge_index = dropped.indices()

        reduce = scatter(self.aggr_func)
        gathered = node_messages.index_select(-2, node_index)
        node_messages_aggregated = reduce(gathered, hyperedge_index, 0)

        return self.messaging_func(x_source, node_messages_aggregated)

    def forward(self, x_source, neighborhood, node_messages):
        r"""Send hyperedge messages and aggregate them.

        Parameters
        ----------
        x_source : torch.Tensor
            Source hyperedge features.
        neighborhood : torch.sparse.Tensor
            Sparse neighborhood matrix. Dropout is applied to it a second time
            here, independently of the call made inside ``message``.
        node_messages : torch.Tensor
            Messages passed from the nodes to the hyperedge.

        Returns
        -------
        torch.Tensor
            Aggregated messages passed from the hyperedge to the nodes.
        """
        x_message = self.message(x_source, neighborhood, node_messages)

        dropped = self.apply_dropout(neighborhood, self.adjacency_dropout)
        # Roles are swapped relative to ``message``: the first index row is now
        # the aggregation target and the second one selects the messages.
        node_index, hyperedge_index = dropped.indices()
        self.target_index_i = node_index

        gathered = x_message.index_select(-2, hyperedge_index)
        return self.aggregate(gathered)


class _DefaultHyperedgeToNodeMessagingFunc(nn.Module):
    r"""Default hyperedge to node messaging function.

    Concatenates hyperedge features with the aggregated node messages along the
    channel (last) axis, maps the result back to ``in_channels`` with a linear
    layer and applies a sigmoid.

    Parameters
    ----------
    in_channels : int
        Dimension of the input features.
    """

    def __init__(self, in_channels) -> None:
        super().__init__()
        self.linear = nn.Linear(2 * in_channels, in_channels)

    def forward(self, x_1, m_0):
        r"""Forward computation.

        Parameters
        ----------
        x_1 : torch.Tensor, shape = (..., in_channels)
            Input hyperedge features.
        m_0 : torch.Tensor, shape = (..., in_channels)
            Aggregated messages from the nodes. Leading dimensions must match
            those of ``x_1``.

        Returns
        -------
        torch.Tensor, shape = (..., in_channels)
            Messages passed from the hyperedge to the nodes.
        """
        # Concatenate on the last axis: that is the axis ``nn.Linear`` acts on.
        # For 2-D inputs this is identical to ``dim=1``; it additionally
        # supports inputs with extra leading (e.g. batch) dimensions.
        stacked = torch.cat((x_1, m_0), dim=-1)
        return F.sigmoid(self.linear(stacked))


class _DefaultUpdatingFunc(nn.Module):
    r"""Default updating function: sigmoid of features plus received messages.

    The module holds no parameters.

    Parameters
    ----------
    in_channels : int
        Dimension of the input features. Accepted but not used by this module.
    """

    def __init__(self, in_channels) -> None:
        super().__init__()

    def forward(self, x, m):
        r"""Combine features with their received messages.

        Parameters
        ----------
        x : torch.Tensor
            Input features.
        m : torch.Tensor
            Messages passed from the neighbors.

        Returns
        -------
        torch.Tensor
            Updated features, ``sigmoid(x + m)``.
        """
        combined = x + m
        return torch.sigmoid(combined)


class HMPNNLayer(nn.Module):
    r"""HMPNN Layer [1]_.

    The layer is a hypergraph comprised of nodes and hyperedges that makes their new
    representation using the input representation and the messages passed between them.
    In this layer, the message passed from a node to its neighboring hyperedges is only
    a function of its input representation, but the message from a hyperedge to its
    neighboring nodes is also a function of the messages received from them beforehand.
    This way, a node could have a more explicit effect on its upper adjacent neighbors
    i.e. the nodes that it shares a hyperedge with.

    .. math::
        \begin{align*}
        &🟥 \quad m_{{y \rightarrow z}}^{(0 \rightarrow 1)} = M_\mathcal{C} (h_y^{t,(0)}, h_z^{t, (1)})\\
        &🟧 \quad m_{z'}^{(0 \rightarrow 1)} = AGG'_{y \in \mathcal{B}(z)} m_{y \rightarrow z}^{(0\rightarrow1)}\\
        &🟧 \quad m_{z}^{(0 \rightarrow 1)} = AGG_{y \in \mathcal{B}(z)} m_{y \rightarrow z}^{(0 \rightarrow 1)}\\
        &🟥 \quad m_{z \rightarrow x}^{(1 \rightarrow0)} = M_\mathcal{B}(h_z^{t,(1)}, m_z^{(1)})\\
        &🟧 \quad m_x^{(1 \rightarrow0)} = AGG_{z \in \mathcal{C}(x)} m_{z \rightarrow x}^{(1 \rightarrow0)}\\
        &🟩 \quad m_x^{(0)} = m_x^{(1 \rightarrow 0)}\\
        &🟩 \quad m_z^{(1)} = m_{z'}^{(0 \rightarrow 1)}\\
        &🟦 \quad h_x^{t+1, (0)} = U^{(0)}(h_x^{t,(0)}, m_x^{(0)})\\
        &🟦 \quad h_z^{t+1,(1)} = U^{(1)}(h_z^{t,(1)}, m_{z}^{(1)})
        \end{align*}

    Parameters
    ----------
    in_channels : int
        Dimension of input features.
    node_to_hyperedge_messaging_func : callable or None, default = None
        Node messaging function as a callable or nn.Module object. If not given, a
        linear plus sigmoid function is used, according to the paper.
    hyperedge_to_node_messaging_func : callable or None, default = None
        Hyperedge messaging function as a callable or nn.Module object. It gets
        hyperedge input features and aggregated messages of nodes as input and returns
        hyperedge messages. If not given, two inputs are concatenated and a linear
        layer reducing back to in_channels plus sigmoid is applied, according to the
        paper.
    adjacency_dropout : float, default = 0.7
        Adjacency dropout rate.
    aggr_func : Literal["sum", "mean", "add"], default="sum"
        Message aggregation function.
    updating_dropout : float, default = 0.5
        Regular dropout rate applied to node and hyperedge features, i.e. the
        probability with which each feature is masked during training.
    updating_func : callable or None, default = None
        The final function or nn.Module object to be called on node and hyperedge
        features to retrieve their new representation. If not given, the received
        message is added to the features and sigmoid is called.
    **kwargs : optional
        Additional arguments for the layer modules. Accepted for interface
        compatibility; not used by this implementation.

    References
    ----------
    .. [1] Heydari S, Livi L.
        Message passing neural networks for hypergraphs.
        ICANN 2022.
        https://arxiv.org/abs/2203.16995
    """

    def __init__(
        self,
        in_channels,
        node_to_hyperedge_messaging_func=None,
        hyperedge_to_node_messaging_func=None,
        adjacency_dropout: float = 0.7,
        aggr_func: Literal["sum", "mean", "add"] = "sum",
        updating_dropout: float = 0.5,
        updating_func=None,
        **kwargs,
    ) -> None:
        super().__init__()

        if node_to_hyperedge_messaging_func is None:
            node_to_hyperedge_messaging_func = nn.Sequential(
                nn.Linear(in_channels, in_channels), nn.Sigmoid()
            )
        self.node_to_hyperedge_messenger = _NodeToHyperedgeMessenger(
            node_to_hyperedge_messaging_func, adjacency_dropout, aggr_func
        )

        if hyperedge_to_node_messaging_func is None:
            hyperedge_to_node_messaging_func = _DefaultHyperedgeToNodeMessagingFunc(
                in_channels
            )
        self.hyperedge_to_node_messenger = _HyperedgeToNodeMessenger(
            hyperedge_to_node_messaging_func, adjacency_dropout, aggr_func
        )

        self.node_batchnorm = nn.BatchNorm1d(in_channels)
        self.hyperedge_batchnorm = nn.BatchNorm1d(in_channels)

        # ``self.dropout`` samples the *keep* mask (1 = feature kept), so its
        # success probability is the complement of the dropout rate. Sampling
        # with ``updating_dropout`` directly would make a rate of 0 drop
        # everything and a rate of 1 drop nothing.
        self.dropout = torch.distributions.Bernoulli(1 - updating_dropout)

        if updating_func is None:
            updating_func = _DefaultUpdatingFunc(in_channels)
        self.updating_func = updating_func

    def apply_regular_dropout(self, x):
        """Apply regular dropout according to the paper.

        Unmasked features in a vector are scaled by (d + k) / d in which k is the
        number of masked features in the vector and d is the total number of
        features. Outside of training mode the input is returned unchanged.

        Parameters
        ----------
        x : torch.Tensor, shape = (..., n_features)
            Input features. The tensor is not modified in place.

        Returns
        -------
        torch.Tensor
            Output features, same shape and dtype as ``x``.
        """
        if not self.training:
            return x

        # 1 marks a kept feature, 0 a masked one.
        mask = self.dropout.sample(x.shape).to(dtype=torch.float, device=x.device)
        # d is the length of each feature vector (last axis), not the number
        # of vectors; k is counted per vector.
        d = x.size(-1)
        k = d - mask.sum(dim=-1, keepdim=True)
        scale = (d + k) / d
        return x * (mask * scale).to(dtype=x.dtype)

    def forward(self, x_0, x_1, incidence_1):
        r"""Forward computation.

        Parameters
        ----------
        x_0 : torch.Tensor, shape = (n_nodes, node_in_channels)
            Input features of the nodes.
        x_1 : torch.Tensor, shape = (n_edges, hyperedge_in_channels)
            Input features of the hyperedges.
        incidence_1 : torch.sparse.Tensor, shape = (n_nodes, n_edges)
            Incidence matrix mapping hyperedges to nodes (B_1).

        Returns
        -------
        x_0 : torch.Tensor, shape = (n_nodes, node_in_channels)
            Output features of the nodes.
        x_1 : torch.Tensor, shape = (n_edges, hyperedge_in_channels)
            Output features of the hyperedges.
        """
        # Nodes -> hyperedges: keep the raw node messages as well, since the
        # hyperedge messages are computed from them.
        node_messages_aggregated, node_messages = self.node_to_hyperedge_messenger(
            x_0, incidence_1
        )
        # Hyperedges -> nodes.
        hyperedge_messages_aggregated = self.hyperedge_to_node_messenger(
            x_1, incidence_1, node_messages
        )

        # Update both ranks from their batch-normalised, dropped-out inputs
        # and the messages they received.
        x_0 = self.updating_func(
            self.apply_regular_dropout(self.node_batchnorm(x_0)),
            hyperedge_messages_aggregated,
        )
        x_1 = self.updating_func(
            self.apply_regular_dropout(self.hyperedge_batchnorm(x_1)),
            node_messages_aggregated,
        )

        return x_0, x_1