Source code for topomodelx.nn.cell.can_layer

"""Cell Attention Network layer."""

from collections.abc import Callable
from typing import Literal

import torch
from torch import nn, topk
from torch.nn import Linear, Parameter, init
from torch.nn import functional as F

from topomodelx.base.aggregation import Aggregation
from topomodelx.base.message_passing import MessagePassing
from topomodelx.utils.scatter import scatter_add, scatter_sum


def softmax(src, index, num_cells: int):
    r"""Compute the softmax of the attention coefficients.

    Parameters
    ----------
    src : torch.Tensor, shape = (n_k_cells, heads)
        Attention coefficients.
    index : torch.Tensor, shape = (n_k_cells)
        Indices of the target cells.
    num_cells : int
        Number of cells in the batch.

    Returns
    -------
    torch.Tensor, shape = (n_k_cells, heads)
        Softmax of the attention coefficients.

    Notes
    -----
    A default implementation of softmax should live in the utils file.
    The maximum element is subtracted from all elements to avoid overflow and underflow.
    """
    src_max = src.max(dim=0, keepdim=True)[0]  # (1, H)
    src -= src_max  # (|n_k_cells|, H)
    src_exp = torch.exp(src)  # (|n_k_cells|, H)
    src_sum = scatter_sum(src_exp, index, dim=0, dim_size=num_cells)[
        index
    ]  # (|n_k_cells|, H)
    return src_exp / (src_sum + 1e-16)  # (|n_k_cells|, H)

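# Minimal usage sketch (illustrative, not part of the original module), assuming the
# `softmax` helper defined above: three attention coefficients are grouped by their
# target index and normalized so that entries sharing a target sum to one.
def _example_softmax() -> None:  # hypothetical helper name, for illustration only
    src = torch.tensor([[1.0], [2.0], [3.0]])  # 3 neighborhood entries, 1 head
    index = torch.tensor([0, 0, 1])  # target cell of each entry
    out = softmax(src, index, num_cells=2)
    print(out)  # rows 0 and 1 sum to 1; row 2 is 1 on its own
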
def add_self_loops(neighborhood):
    """Add self-loops to the neighborhood matrix.

    Parameters
    ----------
    neighborhood : torch.sparse_coo_tensor, shape = (n_k_cells, n_k_cells)
        Neighborhood matrix.

    Returns
    -------
    torch.sparse_coo_tensor, shape = (n_k_cells, n_k_cells)
        Neighborhood matrix with self-loops.

    Notes
    -----
    Add to utils file.
    """
    N = neighborhood.shape[0]
    cell_index, cell_weight = neighborhood._indices(), neighborhood._values()

    # create loop index
    loop_index = torch.arange(0, N, dtype=torch.long, device=neighborhood.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)

    # add loop index to neighborhood
    cell_index = torch.cat([cell_index, loop_index], dim=1)
    cell_weight = torch.cat(
        [cell_weight, torch.ones(N, dtype=torch.float, device=neighborhood.device)]
    )

    return torch.sparse_coo_tensor(
        indices=cell_index,
        values=cell_weight,
        size=(N, N),
    ).coalesce()

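# Minimal usage sketch (illustrative, not part of the original module): a 3 x 3 sparse
# neighborhood with two off-diagonal entries gains ones on its diagonal.
def _example_add_self_loops() -> None:  # hypothetical helper name, for illustration only
    indices = torch.tensor([[0, 1], [1, 0]])
    neighborhood = torch.sparse_coo_tensor(indices, torch.ones(2), size=(3, 3)).coalesce()
    print(add_self_loops(neighborhood).to_dense())
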
class LiftLayer(MessagePassing):
    """Attentional Lift Layer.

    This is adapted from the official implementation of the Cell Attention Network (CAN) [1]_.

    Parameters
    ----------
    in_channels_0 : int
        Number of input channels of the node signal.
    heads : int
        Number of attention heads.
    signal_lift_activation : Callable
        Activation function applied to the lifted signal.
    signal_lift_dropout : float
        Dropout rate applied to the lifted signal.

    References
    ----------
    .. [1] Giusti, Battiloro, Testa, Di Lorenzo, Sardellitti and Barbarossa.
        Cell attention networks (2022).
        Paper: https://arxiv.org/pdf/2209.08179.pdf
        Repository: https://github.com/lrnzgiusti/can
    """

    def __init__(
        self,
        in_channels_0: int,
        heads: int,
        signal_lift_activation: Callable,
        signal_lift_dropout: float,
    ) -> None:
        super().__init__()

        self.in_channels_0 = in_channels_0
        self.att_parameter = nn.Parameter(torch.empty(size=(2 * in_channels_0, heads)))
        self.signal_lift_activation = signal_lift_activation
        self.signal_lift_dropout = signal_lift_dropout

    def reset_parameters(self) -> None:
        """Reinitialize learnable parameters using Xavier uniform initialization."""
        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_uniform_(self.att_parameter.data, gain=gain)

    def message(self, x_source, x_target=None):
        """Construct a message from source 0-cells to a target 1-cell.

        Parameters
        ----------
        x_source : torch.Tensor, shape = (num_edges, in_channels_0)
            Node signal of the source 0-cells.
        x_target : torch.Tensor, shape = (num_edges, in_channels_0)
            Node signal of the target 0-cells.

        Returns
        -------
        torch.Tensor, shape = (num_edges, heads)
            Edge signal.
        """
        # Concatenate source and target node feature vectors
        node_features_stacked = torch.cat((x_source, x_target), dim=1)

        # Compute the output edge signal by applying the activation function
        edge_signal = torch.einsum(
            "ij,jh->ih", node_features_stacked, self.att_parameter
        )  # (num_edges, heads)
        return self.signal_lift_activation(edge_signal)

    def forward(self, x_0, adjacency_0) -> torch.Tensor:  # type: ignore[override]
        """Forward pass.

        Parameters
        ----------
        x_0 : torch.Tensor, shape = (num_nodes, in_channels_0)
            Node signal.
        adjacency_0 : torch.Tensor, shape = (num_nodes, num_nodes)
            Sparse neighborhood matrix.

        Returns
        -------
        torch.Tensor, shape = (num_edges, heads)
            Edge signal.
        """
        # Extract source and target nodes from the graph's edge index
        source, target = adjacency_0.indices()  # (num_edges,)

        # Extract the node signal of the source and target nodes
        x_source = x_0[source]  # (num_edges, in_channels_0)
        x_target = x_0[target]  # (num_edges, in_channels_0)

        # Compute the edge signal
        return self.message(x_source, x_target)  # (num_edges, heads)

class MultiHeadLiftLayer(nn.Module):
    r"""Multi Head Attentional Lift Layer.

    Multi Head Attentional Lift Layer adapted from the official implementation of
    the Cell Attention Network (CAN) [1]_.

    Parameters
    ----------
    in_channels_0 : int
        Number of input channels.
    heads : int, optional
        Number of attention heads.
    signal_lift_activation : Callable, optional
        Activation function to apply to the output edge signal.
    signal_lift_dropout : float, optional
        Dropout rate to apply to the output edge signal.
    signal_lift_readout : str, optional
        Readout method to apply to the output edge signal.
    """

    def __init__(
        self,
        in_channels_0: int,
        heads: int = 1,
        signal_lift_activation: Callable = torch.relu,
        signal_lift_dropout: float = 0.0,
        signal_lift_readout: str = "cat",
    ) -> None:
        super().__init__()

        assert heads > 0, ValueError("Number of heads must be > 0")
        assert signal_lift_readout in [
            "cat",
            "sum",
            "avg",
            "max",
        ], "Invalid readout method."

        self.in_channels_0 = in_channels_0
        self.heads = heads
        self.signal_lift_readout = signal_lift_readout
        self.signal_lift_dropout = signal_lift_dropout
        self.signal_lift_activation = signal_lift_activation
        self.lifts = LiftLayer(
            in_channels_0=in_channels_0,
            heads=heads,
            signal_lift_activation=signal_lift_activation,
            signal_lift_dropout=signal_lift_dropout,
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reinitialize learnable parameters using Xavier uniform initialization."""
        self.lifts.reset_parameters()

    def forward(self, x_0, adjacency_0, x_1=None) -> torch.Tensor:
        r"""Forward pass.

        Parameters
        ----------
        x_0 : torch.Tensor, shape = (num_nodes, in_channels_0)
            Node signal.
        adjacency_0 : torch.sparse_coo_tensor, shape = (num_nodes, num_nodes)
            Sparse neighborhood matrix of the 0-cells.
        x_1 : torch.Tensor, shape = (num_edges, in_channels_1), optional
            Edge signal.

        Returns
        -------
        torch.Tensor, shape = (num_edges, heads + in_channels_1)
            Lifted node signal.

        Notes
        -----
        .. math::
            \begin{align*}
            &🟥 \quad m_{(y,z) \rightarrow x}^{(0 \rightarrow 1)} = \alpha(h_y, h_z) = \Theta(h_z||h_y)\\
            &🟦 \quad h_x^{(1)} = \phi(h_x, m_x^{(1)})
            \end{align*}
        """
        # Lift the node signal for each attention head
        attention_heads_x_1 = self.lifts(x_0, adjacency_0)

        # Combine the output edge signals using the specified readout strategy
        readout_methods = {
            "cat": lambda x: x,
            "sum": lambda x: x.sum(dim=1)[:, None],
            "avg": lambda x: x.mean(dim=1)[:, None],
            "max": lambda x: x.max(dim=1).values[:, None],
        }
        combined_x_1 = readout_methods[self.signal_lift_readout](attention_heads_x_1)

        # Apply dropout to the combined edge signal
        combined_x_1 = F.dropout(
            combined_x_1, self.signal_lift_dropout, training=self.training
        )

        # Concatenate the lifted edge signal with the original edge signal if it is not None
        if x_1 is not None:
            combined_x_1 = torch.cat(
                (combined_x_1, x_1), dim=1
            )  # (num_edges, heads + in_channels_1)

        return combined_x_1

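# Minimal usage sketch (illustrative, not part of the original module): lift a toy node
# signal onto the edges of a two-edge graph with the layers defined above; shapes follow
# the docstrings.
def _example_multi_head_lift_layer() -> None:  # hypothetical helper name, for illustration only
    x_0 = torch.rand(3, 4)  # 3 nodes, 4 node channels
    adjacency_0 = torch.sparse_coo_tensor(
        torch.tensor([[0, 1], [1, 2]]), torch.ones(2), size=(3, 3)
    ).coalesce()
    lift = MultiHeadLiftLayer(in_channels_0=4, heads=2, signal_lift_readout="cat")
    x_1 = lift(x_0, adjacency_0)
    print(x_1.shape)  # (2 edges, 2 heads)
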
class PoolLayer(MessagePassing):
    r"""Attentional Pooling Layer.

    Attentional Pooling Layer adapted from the official implementation of
    the Cell Attention Network (CAN) [1]_.

    Parameters
    ----------
    k_pool : float in (0, 1]
        The pooling ratio, i.e., the fraction of r-cells to keep after the pooling operation.
    in_channels_0 : int
        Number of input channels of the input signal.
    signal_pool_activation : Callable
        Activation function applied to the pooled signal.
    readout : bool, optional
        Whether to apply a readout operation to the pooled signal.
    """

    def __init__(
        self,
        k_pool: float,
        in_channels_0: int,
        signal_pool_activation: Callable,
        readout: bool = True,
    ) -> None:
        super().__init__()

        self.k_pool = k_pool
        self.in_channels_0 = in_channels_0
        self.readout = readout

        # Learnable attention parameter for the pooling operation
        self.att_pool = nn.Parameter(torch.empty(size=(in_channels_0, 1)))
        self.signal_pool_activation = signal_pool_activation

        # Initialize the attention parameter using Xavier initialization
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reinitialize learnable parameters using Xavier uniform initialization."""
        gain = init.calculate_gain("relu")
        init.xavier_uniform_(self.att_pool.data, gain=gain)

    def forward(  # type: ignore[override]
        self, x, down_laplacian_1, up_laplacian_1
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass.

        Parameters
        ----------
        x : torch.Tensor, shape = (n_r_cells, in_channels_r)
            Input r-cell signal.
        down_laplacian_1 : torch.Tensor
            Lower neighborhood matrix.
        up_laplacian_1 : torch.Tensor
            Upper neighborhood matrix.

        Returns
        -------
        torch.Tensor
            Pooled r-cell signal of the retained top-k r-cells.
        torch.sparse_coo_tensor
            Lower neighborhood matrix restricted to the retained r-cells.
        torch.sparse_coo_tensor
            Upper neighborhood matrix restricted to the retained r-cells.

        Notes
        -----
        .. math::
            \begin{align*}
            &🟥 \quad m_{x}^{(r)} = \gamma^t(h_x^t) = \tau^t (a^t\cdot h_x^t)\\
            &🟦 \quad h_x^{t+1,(r)} = \phi^t(h_x^t, m_{x}^{(r)}), \forall x\in \mathcal C_r^{t+1}
            \end{align*}
        """
        # Compute the attention scores by projecting the r-cell signal onto the attention parameter
        Zp = torch.einsum("nc,ce->ne", x, self.att_pool)

        # Apply top-k pooling to the r-cell signal
        _, top_indices = topk(Zp.view(-1), int(self.k_pool * Zp.size(0)))

        # Rescale the pooled signal
        Zp = self.signal_pool_activation(Zp)
        out = x[top_indices] * Zp[top_indices]

        # Readout operation
        if self.readout:
            out = scatter_add(out, top_indices, dim=0, dim_size=x.size(0))[top_indices]

        # Update lower and upper neighborhood matrices with the top-k pooled r-cells
        down_laplacian_1_modified = torch.index_select(down_laplacian_1, 0, top_indices)
        down_laplacian_1_modified = torch.index_select(
            down_laplacian_1_modified, 1, top_indices
        )
        up_laplacian_1_modified = torch.index_select(up_laplacian_1, 0, top_indices)
        up_laplacian_1_modified = torch.index_select(
            up_laplacian_1_modified, 1, top_indices
        )

        # return sparse matrices of the neighborhoods
        return (
            out,
            down_laplacian_1_modified.to_sparse().float().coalesce(),
            up_laplacian_1_modified.to_sparse().float().coalesce(),
        )

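# Minimal usage sketch (illustrative, not part of the original module): keep the top
# half of 4 edge cells and restrict dense toy Laplacians to the kept cells, using the
# PoolLayer defined above.
def _example_pool_layer() -> None:  # hypothetical helper name, for illustration only
    x = torch.rand(4, 8)  # 4 edge cells, 8 channels
    down_laplacian_1 = torch.eye(4)
    up_laplacian_1 = torch.eye(4)
    pool = PoolLayer(k_pool=0.5, in_channels_0=8, signal_pool_activation=torch.sigmoid)
    out, down_pooled, up_pooled = pool(x, down_laplacian_1, up_laplacian_1)
    print(out.shape, down_pooled.shape, up_pooled.shape)  # (2, 8), (2, 2), (2, 2)
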
class MultiHeadCellAttention(MessagePassing):
    """Attentional Message Passing v1.

    Attentional Message Passing from Cell Attention Network (CAN) [1]_
    following the attention mechanism proposed in GAT [2]_.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    dropout : float
        Dropout rate applied to the output signal.
    heads : int
        Number of attention heads.
    concat : bool
        Whether to concatenate the output of each attention head.
    att_activation : Callable
        Activation function to use for the attention weights.
    add_self_loops : bool, optional
        Whether to add self-loops to the adjacency matrix.
    aggr_func : Literal["sum", "mean", "add"], default="sum"
        Aggregation function to use.
    initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
        Initialization method for the weights of the layer.

    Notes
    -----
    If there are no non-zero values in the neighborhood, then the neighborhood is empty
    and the forward pass returns a zero tensor.

    References
    ----------
    .. [2] Veličković, Cucurull, Casanova, Romero, Liò and Bengio.
        Graph attention networks (2017).
        https://arxiv.org/pdf/1710.10903.pdf
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float,
        heads: int,
        concat: bool,
        att_activation: torch.nn.Module,
        add_self_loops: bool = False,
        aggr_func: Literal["sum", "mean", "add"] = "sum",
        initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
    ) -> None:
        super().__init__(aggr_func=aggr_func, att=True, initialization=initialization)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.att_activation = att_activation
        self.heads = heads
        self.concat = concat
        self.dropout = dropout
        self.add_self_loops = add_self_loops

        self.lin = torch.nn.Linear(in_channels, out_channels, bias=False)
        self.att_weight_src = Parameter(torch.Tensor(1, heads, out_channels // heads))
        self.att_weight_dst = Parameter(torch.Tensor(1, heads, out_channels // heads))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the layer parameters."""
        torch.nn.init.xavier_uniform_(self.att_weight_src)
        torch.nn.init.xavier_uniform_(self.att_weight_dst)
        self.lin.reset_parameters()

    def message(self, x_source):
        """Construct message from source cells to target cells.

        🟥 This provides a default message function to the message-passing scheme.

        Parameters
        ----------
        x_source : torch.Tensor, shape = (n_k_cells, channels)
            Input features on the r-cells of the cell complex.

        Returns
        -------
        torch.Tensor, shape = (|n_k_cells|, heads, out_channels // heads)
            Messages on source cells.
        """
        # Compute the linear transformation on the source features
        x_message = self.lin(x_source).view(
            -1, self.heads, self.out_channels // self.heads
        )  # (n_k_cells, H, C)

        # compute the source and target messages
        x_source_per_message = x_message[self.source_index_j]  # (|n_k_cells|, H, C)
        x_target_per_message = x_message[self.target_index_i]  # (|n_k_cells|, H, C)

        # compute the attention coefficients
        alpha = self.attention(
            x_source_per_message, x_target_per_message
        )  # (|n_k_cells|, H)

        # for each head, weight the messages by the attention coefficients
        return x_source_per_message * alpha[:, :, None]  # (|n_k_cells|, H, C)

    def attention(self, x_source, x_target):
        """Compute attention weights for messages.

        Parameters
        ----------
        x_source : torch.Tensor, shape = (|n_k_cells|, heads, in_channels)
            Source cell features.
        x_target : torch.Tensor, shape = (|n_k_cells|, heads, in_channels)
            Target cell features.

        Returns
        -------
        torch.Tensor, shape = (|n_k_cells|, heads)
            Attention weights.
        """
        # Compute attention coefficients
        alpha_src = torch.einsum(
            "ijk,tjk->ij", x_source, self.att_weight_src
        )  # (|n_k_cells|, H)
        alpha_dst = torch.einsum(
            "ijk,tjk->ij", x_target, self.att_weight_dst
        )  # (|n_k_cells|, H)

        alpha = alpha_src + alpha_dst

        # Apply activation function
        alpha = self.att_activation(alpha)

        # Normalize the attention coefficients
        alpha = softmax(alpha, self.target_index_i, x_source.shape[0])

        # Apply dropout
        return F.dropout(alpha, p=self.dropout, training=self.training)

    def forward(self, x_source, neighborhood):
        """Forward pass.

        Parameters
        ----------
        x_source : torch.Tensor, shape = (n_k_cells, channels)
            Input features on the r-cells of the cell complex.
        neighborhood : torch.sparse, shape = (n_k_cells, n_k_cells)
            Neighborhood matrix mapping r-cells to r-cells (A_k).

        Returns
        -------
        torch.Tensor, shape = (n_k_cells, channels)
            Output features on the r-cells of the cell complex.
        """
        # If there are no non-zero values in the neighborhood, the neighborhood is empty: return a zero tensor.
        if not neighborhood.values().nonzero().size(0) > 0 and self.concat:
            return torch.zeros(
                (x_source.shape[0], self.out_channels),
                device=x_source.device,
            )  # (n_k_cells, H * C)
        if not neighborhood.values().nonzero().size(0) > 0 and not self.concat:
            return torch.zeros(
                (x_source.shape[0], self.out_channels // self.heads),
                device=x_source.device,
            )  # (n_k_cells, C)

        # Add self-loops to the neighborhood matrix if necessary
        if self.add_self_loops:
            neighborhood = add_self_loops(neighborhood)

        # Get the indices of the non-zero values in the neighborhood matrix
        (
            self.target_index_i,
            self.source_index_j,
        ) = neighborhood.indices()  # (|n_k_cells|,), (|n_k_cells|,)

        # compute message passing step
        message = self.message(x_source)  # (|n_k_cells|, H, C)
        # compute within-neighborhood aggregation step
        aggregated_message = self.aggregate(message)  # (n_k_cells, H, C)

        # if concat is True, concatenate the messages from each head; otherwise, average them.
        if self.concat:
            return aggregated_message.view(-1, self.out_channels)  # (n_k_cells, H * C)

        return aggregated_message.mean(dim=1)  # (n_k_cells, C)

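# Minimal usage sketch (illustrative, not part of the original module): two attention
# heads over 3 edge cells connected by a toy sparse neighborhood matrix, assuming the
# MessagePassing base class aggregates messages by target index as used above.
def _example_multi_head_cell_attention() -> None:  # hypothetical helper name, for illustration only
    x = torch.rand(3, 4)  # 3 edge cells, 4 channels
    neighborhood = torch.sparse_coo_tensor(
        torch.tensor([[0, 1, 2], [1, 2, 0]]), torch.ones(3), size=(3, 3)
    ).coalesce()
    att = MultiHeadCellAttention(
        in_channels=4,
        out_channels=6,
        dropout=0.0,
        heads=2,
        concat=True,
        att_activation=torch.nn.LeakyReLU(),
    )
    print(att(x, neighborhood).shape)  # (3, 6)
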
class MultiHeadCellAttention_v2(MessagePassing):
    """Attentional Message Passing v2.

    Attentional Message Passing from Cell Attention Network (CAN) [1]_
    following the attention mechanism proposed in GATv2 [3]_.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    dropout : float
        Dropout rate applied to the output signal.
    heads : int
        Number of attention heads.
    concat : bool
        Whether to concatenate the output of each attention head.
    att_activation : Callable
        Activation function to use for the attention weights.
    add_self_loops : bool, optional
        Whether to add self-loops to the adjacency matrix.
    aggr_func : Literal["sum", "mean", "add"], default="sum"
        Aggregation function to use.
    initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
        Initialization method for the weights of the layer.
    share_weights : bool, optional
        Whether to share the weights of the linear transformations applied to the
        source and target features.

    Notes
    -----
    If there are no non-zero values in the neighborhood, then the neighborhood is empty
    and the forward pass returns a zero tensor.

    References
    ----------
    .. [3] Brody, Alon, Yahav.
        How attentive are graph attention networks? (2022).
        https://arxiv.org/pdf/2105.14491.pdf
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float,
        heads: int,
        concat: bool,
        att_activation: torch.nn.Module,
        add_self_loops: bool = True,
        aggr_func: Literal["sum", "mean", "add"] = "sum",
        initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
        share_weights: bool = False,
    ) -> None:
        super().__init__(
            aggr_func=aggr_func,
            att=True,
            initialization=initialization,
        )

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.att_activation = att_activation
        self.heads = heads
        self.concat = concat
        self.dropout = dropout
        self.add_self_loops = add_self_loops

        if share_weights:
            self.lin_src = self.lin_dst = torch.nn.Linear(
                in_channels, out_channels, bias=False
            )
        else:
            self.lin_src = torch.nn.Linear(in_channels, out_channels, bias=False)
            self.lin_dst = torch.nn.Linear(in_channels, out_channels, bias=False)

        self.att_weight = Parameter(torch.Tensor(1, heads, out_channels // heads))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the layer parameters."""
        torch.nn.init.xavier_uniform_(self.att_weight)
        self.lin_src.reset_parameters()
        self.lin_dst.reset_parameters()

    def message(self, x_source):
        """Construct message from source cells to target cells.

        🟥 This provides a default message function to the message-passing scheme.

        Parameters
        ----------
        x_source : torch.Tensor, shape = (n_k_cells, channels)
            Input features on the r-cells of the cell complex.

        Returns
        -------
        torch.Tensor, shape = (|n_k_cells|, heads, out_channels // heads)
            Messages on source cells.
        """
        # Compute the source projection of the features
        x_src_message = self.lin_src(x_source).view(
            -1, self.heads, self.out_channels // self.heads
        )  # (n_k_cells, H, C)

        # Compute the target projection of the features
        x_dst_message = self.lin_dst(x_source).view(
            -1, self.heads, self.out_channels // self.heads
        )  # (n_k_cells, H, C)

        # Gather the source and target projections along the neighborhood
        x_source_per_message = x_src_message[self.source_index_j]  # (|n_k_cells|, H, C)
        x_target_per_message = x_dst_message[self.target_index_i]  # (|n_k_cells|, H, C)

        # Sum the source and target projections of the neighborhood
        x_message = x_source_per_message + x_target_per_message  # (|n_k_cells|, H, C)

        # Compute the attention coefficients
        alpha = self.attention(x_message)  # (|n_k_cells|, H)

        # For each head, weight the messages by the attention coefficients
        return x_source_per_message * alpha[:, :, None]  # (|n_k_cells|, H, C)

    def attention(self, x_source):
        """Compute attention weights for messages.

        Parameters
        ----------
        x_source : torch.Tensor, shape = (|n_k_cells|, heads, in_channels)
            Source cell features.

        Returns
        -------
        torch.Tensor, shape = (|n_k_cells|, heads)
            Attention weights.
        """
        # Apply activation function
        x_source = self.att_activation(x_source)  # (|n_k_cells|, H, C)

        # Compute attention coefficients
        alpha = torch.einsum(
            "ijk,tjk->ij", x_source, self.att_weight
        )  # (|n_k_cells|, H)

        # Normalize the attention coefficients
        alpha = softmax(alpha, self.target_index_i, x_source.shape[0])

        # Apply dropout
        return F.dropout(alpha, p=self.dropout, training=self.training)

    def forward(self, x_source, neighborhood):
        """Forward pass.

        Parameters
        ----------
        x_source : torch.Tensor, shape = (n_k_cells, channels)
            Input features on the r-cells of the cell complex.
        neighborhood : torch.sparse, shape = (n_k_cells, n_k_cells)
            Neighborhood matrix mapping r-cells to r-cells (A_k), [up, down].

        Returns
        -------
        torch.Tensor, shape = (n_k_cells, channels)
            Output features on the r-cells of the cell complex.
        """
        # If there are no non-zero values in the neighborhood, the neighborhood is empty: return a zero tensor.
        if not neighborhood.values().nonzero().size(0) > 0 and self.concat:
            return torch.zeros(
                (x_source.shape[0], self.out_channels),
                device=x_source.device,
            )  # (n_k_cells, H * C)
        if not neighborhood.values().nonzero().size(0) > 0 and not self.concat:
            return torch.zeros(
                (x_source.shape[0], self.out_channels // self.heads),
                device=x_source.device,
            )  # (n_k_cells, C)

        # Add self-loops to the neighborhood matrix if necessary
        if self.add_self_loops:
            neighborhood = add_self_loops(neighborhood)

        # Get the source and target indices of the neighborhood
        (
            self.target_index_i,
            self.source_index_j,
        ) = neighborhood.indices()  # (|n_k_cells|,), (|n_k_cells|,)

        # compute message passing step
        message = self.message(x_source)  # (|n_k_cells|, H, C)
        # compute within-neighborhood aggregation step
        aggregated_message = self.aggregate(message)  # (n_k_cells, H, C)

        # if concat is True, concatenate the messages from each head; otherwise, average them.
        if self.concat:
            return aggregated_message.view(-1, self.out_channels)  # (n_k_cells, H * C)

        return aggregated_message.mean(dim=1)  # (n_k_cells, C)

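# Minimal usage sketch (illustrative, not part of the original module): the GATv2-style
# variant with shared source/target weights on the same kind of toy neighborhood as the
# v1 example above; self-loops are added by default in this class.
def _example_multi_head_cell_attention_v2() -> None:  # hypothetical helper name, for illustration only
    x = torch.rand(3, 4)  # 3 edge cells, 4 channels
    neighborhood = torch.sparse_coo_tensor(
        torch.tensor([[0, 1, 2], [1, 2, 0]]), torch.ones(3), size=(3, 3)
    ).coalesce()
    att = MultiHeadCellAttention_v2(
        in_channels=4,
        out_channels=6,
        dropout=0.0,
        heads=2,
        concat=True,
        att_activation=torch.nn.LeakyReLU(),
        share_weights=True,
    )
    print(att(x, neighborhood).shape)  # (3, 6)
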
class CANLayer(torch.nn.Module):
    r"""Layer of the Cell Attention Network (CAN) model.

    The CAN layer performs attentional convolutional message passing through the upper
    and lower neighborhoods of the cell. Additionally, a skip connection can be added
    to the output of the layer.

    Parameters
    ----------
    in_channels : int
        Dimension of input features on n-cells.
    out_channels : int
        Dimension of output.
    heads : int, default=1
        Number of attention heads.
    dropout : float, optional
        Dropout probability of the normalized attention coefficients.
    concat : bool, default=True
        If True, the output of each head is concatenated. Otherwise, the output of each head is averaged.
    skip_connection : bool, default=True
        If True, a skip connection is added.
    att_activation : Callable, default=torch.nn.LeakyReLU()
        Activation function applied to the attention coefficients.
    add_self_loops : bool, optional
        If True, self-loops are added to the neighborhood matrix.
    aggr_func : Literal["mean", "sum"], default="sum"
        Between-neighborhood aggregation function applied to the messages.
    update_func : Literal["relu", "sigmoid", "tanh", None], default="relu"
        Update function applied to the messages.
    version : Literal["v1", "v2"], default="v1"
        Version of the layer. "v1" is the original CAN layer, while "v2" uses the same
        attention mechanism as the GATv2 layer.
    share_weights : bool, default=False
        This option is valid only for "v2". If True, the weights of the linear
        transformations applied to the source and target features are shared.
    **kwargs : optional
        Additional arguments of the CAN layer.

    Notes
    -----
    add_self_loops is preferably set to False. If necessary, the self-loops should be
    added to the neighborhood matrix in the preprocessing step.
    """

    lower_att: MultiHeadCellAttention | MultiHeadCellAttention_v2
    upper_att: MultiHeadCellAttention | MultiHeadCellAttention_v2

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        heads: int = 1,
        dropout: float = 0.0,
        concat: bool = True,
        skip_connection: bool = True,
        att_activation: torch.nn.Module | None = None,
        add_self_loops: bool = True,
        aggr_func: Literal["mean", "sum"] = "sum",
        update_func: Literal["relu", "sigmoid", "tanh"] | None = "relu",
        version: Literal["v1", "v2"] = "v1",
        share_weights: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()

        if att_activation is None:
            att_activation = torch.nn.LeakyReLU()

        assert in_channels > 0, ValueError("Number of input channels must be > 0")
        assert out_channels > 0, ValueError("Number of output channels must be > 0")
        assert heads > 0, ValueError("Number of heads must be > 0")
        assert out_channels % heads == 0, ValueError(
            "Number of output channels must be divisible by the number of heads"
        )
        assert dropout >= 0.0 and dropout <= 1.0, ValueError("Dropout must be in [0,1]")

        # assert that shared weights is used only if version is v2
        assert share_weights is False or version == "v2", ValueError(
            "Shared weights is valid only for v2"
        )

        if version == "v1":
            # lower attention
            self.lower_att = MultiHeadCellAttention(
                in_channels=in_channels,
                out_channels=out_channels,
                add_self_loops=add_self_loops,
                dropout=dropout,
                heads=heads,
                att_activation=att_activation,
                concat=concat,
            )
            # upper attention
            self.upper_att = MultiHeadCellAttention(
                in_channels=in_channels,
                out_channels=out_channels,
                add_self_loops=add_self_loops,
                dropout=dropout,
                heads=heads,
                att_activation=att_activation,
                concat=concat,
            )
        elif version == "v2":
            # lower attention
            self.lower_att = MultiHeadCellAttention_v2(
                in_channels=in_channels,
                out_channels=out_channels,
                add_self_loops=add_self_loops,
                dropout=dropout,
                heads=heads,
                att_activation=att_activation,
                concat=concat,
                share_weights=share_weights,
            )
            # upper attention
            self.upper_att = MultiHeadCellAttention_v2(
                in_channels=in_channels,
                out_channels=out_channels,
                add_self_loops=add_self_loops,
                dropout=dropout,
                heads=heads,
                att_activation=att_activation,
                concat=concat,
                share_weights=share_weights,
            )

        # linear transformation for the skip connection
        if skip_connection:
            out_channels = out_channels if concat else out_channels // heads
            self.lin = Linear(in_channels, out_channels, bias=False)
            self.eps = 1 + 1e-6

        # between-neighborhood aggregation and update
        self.aggregation = Aggregation(aggr_func=aggr_func, update_func=update_func)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the parameters of the layer."""
        self.lower_att.reset_parameters()
        self.upper_att.reset_parameters()
        if hasattr(self, "lin"):
            self.lin.reset_parameters()

    def forward(self, x, down_laplacian_1, up_laplacian_1) -> torch.Tensor:
        r"""Forward pass.

        Parameters
        ----------
        x : torch.Tensor, shape = (n_k_cells, channels)
            Input features on the r-cells of the cell complex.
        down_laplacian_1 : torch.sparse, shape = (n_k_cells, n_k_cells)
            Lower neighborhood matrix mapping r-cells to r-cells (A_k_low).
        up_laplacian_1 : torch.sparse, shape = (n_k_cells, n_k_cells)
            Upper neighborhood matrix mapping r-cells to r-cells (A_k_up).

        Returns
        -------
        torch.Tensor, shape = (n_k_cells, out_channels)
            Output features on the r-cells of the cell complex.

        Notes
        -----
        .. math::
            \mathcal N = \{\mathcal N_1, \mathcal N_2\} = \{A_{\uparrow, r}, A_{\downarrow, r}\}

        .. math::
            \begin{align*}
            &🟥 \quad m_{(y \rightarrow x),k}^{(r)} = \alpha_k(h_x^t,h_y^t) = a_k(h_x^{t}, h_y^{t}) \cdot \psi_k^t(h_x^{t})\quad \forall \mathcal N_k\\
            &🟧 \quad m_{x,k}^{(r)} = \bigoplus_{y \in \mathcal{N}_k(x)} m^{(r)} _{(y \rightarrow x),k}\\
            &🟩 \quad m_{x}^{(r)} = \bigotimes_{\mathcal{N}_k\in\mathcal N}m_{x,k}^{(r)}\\
            &🟦 \quad h_x^{t+1,(r)} = \phi^{t}(h_x^t, m_{x}^{(r)})
            \end{align*}
        """
        # message passing and within-neighborhood aggregation
        lower_x = self.lower_att(x, down_laplacian_1)
        upper_x = self.upper_att(x, up_laplacian_1)

        # skip connection
        if hasattr(self, "lin"):
            w_x = self.lin(x) * self.eps

        # between-neighborhood aggregation and update
        return (
            self.aggregation([lower_x, upper_x, w_x])
            if hasattr(self, "lin")
            else self.aggregation([lower_x, upper_x])
        )

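# Minimal usage sketch (illustrative, not part of the original module): one CAN layer
# over 4 edge cells with toy sparse lower/upper neighborhood matrices, assuming the
# Aggregation and MessagePassing base classes behave as used above.
def _example_can_layer() -> None:  # hypothetical helper name, for illustration only
    x = torch.rand(4, 8)  # 4 edge cells, 8 channels
    indices = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
    down_laplacian_1 = torch.sparse_coo_tensor(indices, torch.ones(4), size=(4, 4)).coalesce()
    up_laplacian_1 = torch.sparse_coo_tensor(indices, torch.ones(4), size=(4, 4)).coalesce()
    layer = CANLayer(in_channels=8, out_channels=8, heads=2)
    print(layer(x, down_laplacian_1, up_laplacian_1).shape)  # (4, 8)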