"""CAN class."""
import torch
import torch.nn.functional as F
from topomodelx.nn.cell.can_layer import CANLayer, MultiHeadLiftLayer, PoolLayer


class CAN(torch.nn.Module):
"""CAN (Cell Attention Network) [1]_ module for graph classification.
Parameters
----------
in_channels_0 : int
Number of input channels for the node-level input.
in_channels_1 : int
Number of input channels for the edge-level input.
out_channels : int
Number of output channels.
dropout : float, optional
Dropout probability. Default is 0.5.
heads : int, optional
Number of attention heads. Default is 2.
concat : bool, optional
Whether to concatenate the output channels of attention heads. Default is True.
skip_connection : bool, optional
Whether to use skip connections. Default is True.
att_activation : torch.nn.Module, optional
Activation function for attention mechanism. Default is torch.nn.LeakyReLU(0.2).
    n_layers : int, optional
        Number of CAN layers. Default is 2.
    att_lift : bool, optional
        Whether to lift the node-level signal to the edge level and concatenate
        it to the edge-level input. Default is True.
    pooling : bool, optional
        Whether to apply a pooling operation after the CAN layers. Default is False.
    k_pool : float, optional
        Pooling ratio, i.e. the fraction of r-cells to keep after the pooling operation. Default is 0.5.
    **kwargs : optional
        Additional arguments passed to CANLayer.

    References
    ----------
.. [1] Giusti, Battiloro, Testa, Di Lorenzo, Sardellitti and Barbarossa.
Cell attention networks (2022).
Paper: https://arxiv.org/pdf/2209.08179.pdf
Repository: https://github.com/lrnzgiusti/can
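
    Examples
    --------
    A minimal, illustrative sketch (not from the original source): the random
    features and identity neighborhood matrices below are placeholder
    assumptions, chosen only so the tensor shapes are mutually consistent.

    >>> import torch
    >>> from topomodelx.nn.cell.can import CAN
    >>> x_0 = torch.rand(4, 2)  # 4 nodes, 2 input channels each
    >>> x_1 = torch.rand(5, 3)  # 5 edges, 3 input channels each
    >>> adjacency_0 = torch.eye(4).to_sparse()
    >>> down_laplacian_1 = torch.eye(5).to_sparse()
    >>> up_laplacian_1 = torch.eye(5).to_sparse()
    >>> model = CAN(in_channels_0=2, in_channels_1=3, out_channels=8, att_lift=False)
    >>> x_1 = model(x_0, x_1, adjacency_0, down_laplacian_1, up_laplacian_1)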
"""
def __init__(
self,
in_channels_0,
in_channels_1,
out_channels,
dropout=0.5,
heads=2,
concat=True,
skip_connection=True,
att_activation=None,
n_layers=2,
att_lift=True,
pooling=False,
k_pool=0.5,
**kwargs,
):
super().__init__()
if att_activation is None:
att_activation = torch.nn.LeakyReLU(0.2)
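        # Optionally lift the node signal onto the edges; the lifted channels are
        # concatenated to x_1, so the edge-level input width grows by in_channels_0.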
if att_lift:
self.lift_layer = MultiHeadLiftLayer(
in_channels_0=in_channels_0,
heads=in_channels_0,
signal_lift_dropout=0.5,
)
in_channels_1 = in_channels_1 + in_channels_0
layers = []
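        # First CAN layer maps the (possibly lifted) edge features to out_channels.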
layers.append(
CANLayer(
in_channels=in_channels_1,
out_channels=out_channels,
heads=heads,
concat=concat,
skip_connection=skip_connection,
att_activation=att_activation,
aggr_func="sum",
update_func="relu",
**kwargs,
)
)
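        # Remaining layers keep the hidden edge representation at out_channels.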
for _ in range(n_layers - 1):
layers.append(
CANLayer(
in_channels=out_channels,
out_channels=out_channels,
dropout=dropout,
heads=heads,
concat=concat,
skip_connection=skip_connection,
att_activation=att_activation,
aggr_func="sum",
update_func="relu",
**kwargs,
)
)
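        # Optional attention pooling: keeps a k_pool fraction of the edges and
        # coarsens the neighborhood matrices accordingly.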
if pooling:
layers.append(
PoolLayer(
k_pool=k_pool,
in_channels_0=out_channels,
signal_pool_activation=torch.nn.Sigmoid(),
readout=True,
**kwargs,
)
)
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x_0, x_1, adjacency_0, down_laplacian_1, up_laplacian_1):
"""Forward pass.
Parameters
----------
x_0 : torch.Tensor, shape = (n_nodes, in_channels_0)
Input features on the nodes (0-cells).
x_1 : torch.Tensor, shape = (n_edges, in_channels_1)
Input features on the edges (1-cells).
adjacency_0 : torch.Tensor, shape = (n_nodes, n_nodes)
Neighborhood matrix from nodes to nodes.
        down_laplacian_1 : torch.Tensor, shape = (n_edges, n_edges)
            Lower neighborhood matrix of the edges (1-cells).
        up_laplacian_1 : torch.Tensor, shape = (n_edges, n_edges)
            Upper neighborhood matrix of the edges (1-cells).

        Returns
        -------
torch.Tensor, shape = (num_pooled_edges, heads * out_channels)
Final hidden representations of pooled edges.
"""
adjacency_0 = adjacency_0.coalesce()
down_laplacian_1 = down_laplacian_1.coalesce()
up_laplacian_1 = up_laplacian_1.coalesce()
if hasattr(self, "lift_layer"):
            x_1 = self.lift_layer(x_0, adjacency_0, x_1)
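        # PoolLayer coarsens the complex, so it returns updated Laplacians
        # alongside the pooled edge features; CANLayer only updates x_1.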
for layer in self.layers:
if isinstance(layer, PoolLayer):
x_1, down_laplacian_1, up_laplacian_1 = layer(
x_1, down_laplacian_1, up_laplacian_1
)
else:
x_1 = layer(x_1, down_laplacian_1, up_laplacian_1)
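                # Fixed inter-layer dropout (p=0.5); this is independent of the
                # `dropout` constructor argument, which only configures CANLayer.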
x_1 = F.dropout(x_1, p=0.5, training=self.training)
return x_1