# Source code for topomodelx.base.conv

"""Convolutional layer for message passing."""
from typing import Literal

import torch
from torch.nn.parameter import Parameter

from topomodelx.base.message_passing import MessagePassing


class Conv(MessagePassing):
    """Message passing: steps 1, 2, and 3.

    Builds the message passing route given by one neighborhood matrix.
    Includes an option for an x-specific update function.

    Parameters
    ----------
    in_channels : int
        Dimension of input features.
    out_channels : int
        Dimension of output features.
    aggr_norm : bool, default=False
        Whether to normalize the aggregated message by the neighborhood size,
        i.e. divide each target row by the corresponding row sum of the
        (possibly attention-weighted) neighborhood matrix. Rows whose sum is
        exactly zero (e.g. target cells without neighbors) are set to zero
        instead of becoming NaN.
    update_func : {"relu", "sigmoid"}, optional
        Update method to apply to message. Any other value, including the
        default ``None``, leaves the message unchanged.
    att : bool, default=False
        Whether to use attention.
    initialization : {"xavier_uniform", "xavier_normal"}, default="xavier_uniform"
        Initialization method.
    initialization_gain : float, default=1.414
        Initialization gain.
    with_linear_transform : bool, default=True
        Whether to apply a learnable linear transform.
        NB: if `False` in_channels has to be equal to out_channels.

    Raises
    ------
    ValueError
        If ``with_linear_transform`` is ``False`` and
        ``in_channels != out_channels``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        aggr_norm: bool = False,
        update_func: Literal["relu", "sigmoid", None] = None,
        att: bool = False,
        initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
        initialization_gain: float = 1.414,
        with_linear_transform: bool = True,
    ) -> None:
        super().__init__(
            att=att,
            initialization=initialization,
            initialization_gain=initialization_gain,
        )
        # Without a learnable projection the messages keep their input width,
        # so the output width must match. Validate before allocating anything.
        if not with_linear_transform and in_channels != out_channels:
            raise ValueError(
                "With `with_linear_transform=False`, in_channels has to be equal to out_channels"
            )

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aggr_norm = aggr_norm
        self.update_func = update_func

        # Allocated uninitialized; ``reset_parameters`` (called at the end of
        # ``__init__``) is presumably responsible for initializing it.
        self.weight = (
            Parameter(torch.Tensor(self.in_channels, self.out_channels))
            if with_linear_transform
            else None
        )

        if self.att:
            # Width 2 * in_channels: presumably scores a concatenation of a
            # source and a target feature vector in the inherited
            # ``attention`` method -- TODO confirm against MessagePassing.
            self.att_weight = Parameter(
                torch.Tensor(
                    2 * self.in_channels,
                )
            )

        self.reset_parameters()

    def update(self, x_message_on_target) -> torch.Tensor:
        """Update embeddings on each cell (step 4).

        Parameters
        ----------
        x_message_on_target : torch.Tensor, shape = (n_target_cells, out_channels)
            Output features on target cells.

        Returns
        -------
        torch.Tensor, shape = (n_target_cells, out_channels)
            Updated output features on target cells: ``sigmoid`` or ``relu``
            applied elementwise according to ``update_func``, otherwise the
            input returned unchanged.
        """
        if self.update_func == "sigmoid":
            return torch.sigmoid(x_message_on_target)
        if self.update_func == "relu":
            return torch.nn.functional.relu(x_message_on_target)
        return x_message_on_target

    @staticmethod
    def _normalize_by_neighborhood_size(
        neighborhood, x_message_on_target
    ) -> torch.Tensor:
        """Divide each target row by the row sum of ``neighborhood``.

        Rows whose sum is exactly zero are returned as zeros instead of NaN.
        """
        # neighborhood @ ones yields the row sums for both sparse and dense
        # matrices, without materializing a dense (n_target, n_source) copy.
        ones = torch.ones(
            neighborhood.shape[1],
            1,
            dtype=x_message_on_target.dtype,
            device=x_message_on_target.device,
        )
        neighborhood_size = torch.mm(neighborhood, ones).squeeze(1)
        has_size = neighborhood_size != 0
        # Swap zeros for ones *before* taking the reciprocal so that neither
        # the forward value nor the backward gradient contains inf/NaN.
        safe_size = neighborhood_size.masked_fill(~has_size, 1.0)
        inv_size = torch.where(
            has_size, 1.0 / safe_size, torch.zeros_like(safe_size)
        )
        return inv_size.unsqueeze(1) * x_message_on_target

    def forward(self, x_source, neighborhood, x_target=None) -> torch.Tensor:
        """Forward pass.

        This implements message passing:
        - from source cells with input features `x_source`,
        - via `neighborhood` defining where messages can pass,
        - to target cells with input features `x_target`.

        In practice, this will update the features on the target cells.

        If not provided, x_target is assumed to be x_source,
        i.e. source cells send messages to themselves.

        Parameters
        ----------
        x_source : Tensor, shape = (n_source_cells, in_channels)
            Input features on source cells.
            Assumes that all source cells have the same rank r.
            Must be 2-D: ``torch.mm`` does not broadcast over batch dimensions.
        neighborhood : torch.sparse, shape = (n_target_cells, n_source_cells)
            Neighborhood matrix. Must be a sparse COO tensor when ``att=True``,
            because it is coalesced and its indices/values are read.
        x_target : Tensor, shape = (n_target_cells, in_channels), optional
            Input features on target cells.
            Assumes that all target cells have the same rank s.
            Only forwarded to the attention mechanism, so it is ignored when
            ``att=False``. If not provided, x_target is assumed to be
            x_source, i.e. source cells send messages to themselves.

        Returns
        -------
        torch.Tensor, shape = (n_target_cells, out_channels)
            Output features on target cells.
            Assumes that all target cells have the same rank s.
        """
        if self.att:
            neighborhood = neighborhood.coalesce()
            # Stored on the instance, presumably for the inherited
            # ``attention`` method, which only receives the features.
            self.target_index_i, self.source_index_j = neighborhood.indices()
            attention_values = self.attention(x_source, x_target)
            # Reweight every existing edge by its attention score while
            # keeping the sparsity pattern.
            neighborhood = torch.sparse_coo_tensor(
                indices=neighborhood.indices(),
                values=attention_values * neighborhood.values(),
                size=neighborhood.shape,
            )

        x_message = (
            x_source if self.weight is None else torch.mm(x_source, self.weight)
        )
        x_message_on_target = torch.mm(neighborhood, x_message)

        if self.aggr_norm:
            x_message_on_target = self._normalize_by_neighborhood_size(
                neighborhood, x_message_on_target
            )

        return self.update(x_message_on_target)