# Source code for topomodelx.base.message_passing

"""Message passing module."""
import math
from typing import Literal

import torch

from topomodelx.utils.scatter import scatter


[docs] class MessagePassing(torch.nn.Module): """Define message passing. This class defines message passing through a single neighborhood N, by decomposing it into 2 steps: 1. 🟥 Create messages going from source cells to target cells through N. 2. 🟧 Aggregate messages coming from different sources cells onto each target cell. This class should not be instantiated directly, but rather inherited through subclasses that effectively define a message passing function. This class does not have trainable weights, but its subclasses should define these weights. Parameters ---------- aggr_func : Literal["sum", "mean", "add"], default="sum" Aggregation function to use. att : bool, default=False Whether to use attention. initialization : Literal["uniform", "xavier_uniform", "xavier_normal"], default="xavier_uniform" Initialization method for the weights of the layer. initialization_gain : float, default=1.414 Gain for the weight initialization. References ---------- .. [1] Hajij, Zamzmi, Papamarkou, Miolane, Guzmán-Sáenz, Ramamurthy, Birdal, Dey, Mukherjee, Samaga, Livesay, Walters, Rosen, Schaub. Topological deep learning: going beyond graph data (2023). https://arxiv.org/abs/2206.00606. .. [2] Papillon, Sanborn, Hajij, Miolane. Architectures of topological deep learning: a survey on topological neural networks (2023). https://arxiv.org/abs/2304.10031. """ def __init__( self, aggr_func: Literal["sum", "mean", "add"] = "sum", att: bool = False, initialization: Literal[ "uniform", "xavier_uniform", "xavier_normal" ] = "xavier_uniform", initialization_gain: float = 1.414, ) -> None: super().__init__() self.aggr_func = aggr_func self.att = att self.initialization = initialization self.initialization_gain = initialization_gain
[docs] def reset_parameters(self): r"""Reset learnable parameters. Notes ----- This function will be called by subclasses of MessagePassing that have trainable weights. """ match self.initialization: case "uniform": if self.weight is not None: stdv = 1.0 / math.sqrt(self.weight.size(1)) self.weight.data.uniform_(-stdv, stdv) if self.att: stdv = 1.0 / math.sqrt(self.att_weight.size(1)) self.att_weight.data.uniform_(-stdv, stdv) case "xavier_uniform": if self.weight is not None: torch.nn.init.xavier_uniform_( self.weight, gain=self.initialization_gain ) if self.att: torch.nn.init.xavier_uniform_( self.att_weight.view(-1, 1), gain=self.initialization_gain ) case "xavier_normal": if self.weight is not None: torch.nn.init.xavier_normal_( self.weight, gain=self.initialization_gain ) if self.att: torch.nn.init.xavier_normal_( self.att_weight.view(-1, 1), gain=self.initialization_gain ) case _: raise ValueError( f"Initialization {self.initialization} not recognized." )
[docs] def message(self, x_source, x_target=None): """Construct message from source cells to target cells. 🟥 This provides a default message function to the message passing scheme. Alternatively, users can subclass MessagePassing and overwrite the message method in order to replace it with their own message mechanism. Parameters ---------- x_source : Tensor, shape = (..., n_source_cells, in_channels) Input features on source cells. Assumes that all source cells have the same rank r. x_target : Tensor, shape = (..., n_target_cells, in_channels) Input features on target cells. Assumes that all target cells have the same rank s. Optional. If not provided, x_target is assumed to be x_source, i.e. source cells send messages to themselves. Returns ------- torch.Tensor, shape = (..., n_source_cells, in_channels) Messages on source cells. """ return x_source
[docs] def attention(self, x_source, x_target=None): """Compute attention weights for messages. This provides a default attention function to the message-passing scheme. Alternatively, users can subclass MessagePassing and overwrite the attention method in order to replace it with their own attention mechanism. The implementation follows [1]_. Parameters ---------- x_source : torch.Tensor, shape = (n_source_cells, in_channels) Input features on source cells. Assumes that all source cells have the same rank r. x_target : torch.Tensor, shape = (n_target_cells, in_channels) Input features on source cells. Assumes that all source cells have the same rank r. Returns ------- torch.Tensor, shape = (n_messages, 1) Attention weights: one scalar per message between a source and a target cell. """ x_source_per_message = x_source[self.source_index_j] x_target_per_message = ( x_source[self.target_index_i] if x_target is None else x_target[self.target_index_i] ) x_source_target_per_message = torch.cat( [x_source_per_message, x_target_per_message], dim=1 ) return torch.nn.functional.elu( torch.matmul(x_source_target_per_message, self.att_weight) )
[docs] def aggregate(self, x_message): """Aggregate messages on each target cell. A target cell receives messages from several source cells. This function aggregates these messages into a single output feature per target cell. 🟧 This function corresponds to the within-neighborhood aggregation defined in [1]_ and [2]_. Parameters ---------- x_message : torch.Tensor, shape = (..., n_messages, out_channels) Features associated with each message. One message is sent from a source cell to a target cell. Returns ------- Tensor, shape = (..., n_target_cells, out_channels) Output features on target cells. Each target cell aggregates messages from several source cells. Assumes that all target cells have the same rank s. """ aggr = scatter(self.aggr_func) return aggr(x_message, self.target_index_i, 0)
[docs] def forward(self, x_source, neighborhood, x_target=None): r"""Forward pass. This implements message passing for a given neighborhood: - from source cells with input features `x_source`, - via `neighborhood` defining where messages can pass, - to target cells with input features `x_target`. In practice, this will update the features on the target cells. If not provided, x_target is assumed to be x_source, i.e. source cells send messages to themselves. The message passing is decomposed into two steps: 1. 🟥 Message: A message :math:`m_{y \rightarrow x}^{\left(r \rightarrow s\right)}` travels from a source cell :math:`y` of rank r to a target cell :math:`x` of rank s through a neighborhood of :math:`x`, denoted :math:`\mathcal{N} (x)`, via the message function :math:`M_\mathcal{N}`: .. math:: m_{y \rightarrow x}^{\left(r \rightarrow s\right)} = M_{\mathcal{N}}\left(\mathbf{h}_x^{(s)}, \mathbf{h}_y^{(r)}, \Theta \right), where: - :math:`\mathbf{h}_y^{(r)}` are input features on the source cells, called `x_source`, - :math:`\mathbf{h}_x^{(s)}` are input features on the target cells, called `x_target`, - :math:`\Theta` are optional parameters (weights) of the message passing function. Optionally, attention can be applied to the message, such that: .. math:: m_{y \rightarrow x}^{\left(r \rightarrow s\right)} \leftarrow att(\mathbf{h}_y^{(r)}, \mathbf{h}_x^{(s)}) . m_{y \rightarrow x}^{\left(r \rightarrow s\right)} 2. 🟧 Aggregation: Messages are aggregated across source cells :math:`y` belonging to the neighborhood :math:`\mathcal{N}(x)`: .. math:: m_x^{\left(r \rightarrow s\right)} = \text{AGG}_{y \in \mathcal{N}(x)} m_{y \rightarrow x}^{\left(r\rightarrow s\right)}, resulting in the within-neighborhood aggregated message :math:`m_x^{\left(r \rightarrow s\right)}`. Details can be found in [1]_ and [2]_. Parameters ---------- x_source : Tensor, shape = (..., n_source_cells, in_channels) Input features on source cells. 
Assumes that all source cells have the same rank r. neighborhood : torch.sparse, shape = (n_target_cells, n_source_cells) Neighborhood matrix. x_target : Tensor, shape = (..., n_target_cells, in_channels) Input features on target cells. Assumes that all target cells have the same rank s. Optional. If not provided, x_target is assumed to be x_source, i.e. source cells send messages to themselves. Returns ------- torch.Tensor, shape = (..., n_target_cells, out_channels) Output features on target cells. Assumes that all target cells have the same rank s. """ neighborhood = neighborhood.coalesce() self.target_index_i, self.source_index_j = neighborhood.indices() neighborhood_values = neighborhood.values() x_message = self.message(x_source=x_source, x_target=x_target) x_message = x_message.index_select(-2, self.source_index_j) if self.att: attention_values = self.attention(x_source=x_source, x_target=x_target) neighborhood_values = torch.multiply(neighborhood_values, attention_values) x_message = neighborhood_values.view(-1, 1) * x_message return self.aggregate(x_message)