Can_Layer

Cell Attention Network layer.

class topomodelx.nn.cell.can_layer.CANLayer(in_channels: int, out_channels: int, heads: int = 1, dropout: float = 0.0, concat: bool = True, skip_connection: bool = True, att_activation: Module | None = None, add_self_loops: bool = True, aggr_func: Literal['mean', 'sum'] = 'sum', update_func: Literal['relu', 'sigmoid', 'tanh'] | None = 'relu', version: Literal['v1', 'v2'] = 'v1', share_weights: bool = False, **kwargs)

Layer of the Cell Attention Network (CAN) model.

The CAN layer performs attentional convolutional message passing through the upper and lower neighborhoods of each cell. Optionally, a skip connection can be added to the output of the layer.

Parameters:
in_channels : int

Dimension of input features on n-cells.

out_channels : int

Dimension of output features.

heads : int, default=1

Number of attention heads.

dropout : float, default=0.0

Dropout probability of the normalized attention coefficients.

concat : bool, default=True

If True, the outputs of the heads are concatenated. Otherwise, they are averaged.

skip_connection : bool, default=True

If True, a skip connection is added.

att_activation : Callable, default=torch.nn.LeakyReLU()

Activation function applied to the attention coefficients.

add_self_loops : bool, default=True

If True, self-loops are added to the neighborhood matrix.

aggr_func : Literal["mean", "sum"], default="sum"

Between-neighborhood aggregation function applied to the messages.

update_func : Literal["relu", "sigmoid", "tanh", None], default="relu"

Update function applied to the messages.

version : Literal["v1", "v2"], default="v1"

Version of the layer. "v1" (the default) matches the original CAN layer, while "v2" uses the same attention mechanism as the GATv2 layer.

share_weights : bool, default=False

Only valid for "v2". If True, the linear transformation applied to the source and target features shares its weights.

**kwargs : optional

Additional arguments of the CAN layer.
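How concat combines the per-head outputs can be illustrated on plain Python lists. The sketch below uses a hypothetical helper (combine_heads is not part of TopoModelX) and assumes every head returns a feature vector of the same length for a given cell.

```python
def combine_heads(head_outputs, concat=True):
    """Combine per-head feature vectors for a single cell.

    head_outputs: list of `heads` lists, each of length d.
    concat=True  -> one list of length heads * d (heads laid side by side).
    concat=False -> one list of length d (element-wise mean over heads).
    """
    if concat:
        return [value for head in head_outputs for value in head]
    num_heads = len(head_outputs)
    return [sum(values) / num_heads for values in zip(*head_outputs)]


heads = [[1.0, 2.0], [3.0, 4.0]]  # two heads, two features each
print(combine_heads(heads, concat=True))   # [1.0, 2.0, 3.0, 4.0]
print(combine_heads(heads, concat=False))  # [2.0, 3.0]
```

Concatenation multiplies the output width by the number of heads, while averaging keeps it equal to the per-head width.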

Methods

add_module(name, module)

Adds a child module to the current module.

apply(fn)

Applies fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Returns an iterator over module buffers.

children()

Returns an iterator over immediate children modules.

cpu()

Moves all model parameters and buffers to the CPU.

cuda([device])

Moves all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Sets the module in evaluation mode.

extra_repr()

Set the extra representation of the module

float()

Casts all floating point parameters and buffers to float datatype.

forward(x, down_laplacian_1, up_laplacian_1)

Forward pass.

get_buffer(target)

Returns the buffer given by target if it exists, otherwise throws an error.

get_extra_state()

Returns any extra state to include in the module's state_dict.

get_parameter(target)

Returns the parameter given by target if it exists, otherwise throws an error.

get_submodule(target)

Returns the submodule given by target if it exists, otherwise throws an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Moves all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict])

Copies parameters and buffers from state_dict into this module and its descendants.

modules()

Returns an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Returns an iterator over module parameters.

register_backward_hook(hook)

Registers a backward hook on the module.

register_buffer(name, tensor[, persistent])

Adds a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Registers a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Registers a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Registers a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Registers a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Registers a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Adds a parameter to the module.

register_state_dict_pre_hook(hook)

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

reset_parameters()

Reset the parameters of the layer.

set_extra_state(state)

This function is called from load_state_dict() to handle any extra state found within the state_dict.

share_memory()

See torch.Tensor.share_memory_()

state_dict(*args[, destination, prefix, ...])

Returns a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

to_empty(*, device)

Moves the parameters and buffers to the specified device without copying storage.

train([mode])

Sets the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Moves all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Sets gradients of all model parameters to zero.

__call__

Notes

Setting add_self_loops=False is preferred: if self-loops are needed, add them to the neighborhood matrix during preprocessing.
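A minimal preprocessing sketch, assuming the neighborhood matrix is stored in coordinate (COO) form as plain Python lists of (row, col) pairs and values. The helper add_self_loops is hypothetical; with PyTorch the same idea would be applied to the indices and values of a sparse COO tensor before it is passed to the layer.

```python
def add_self_loops(indices, values, num_cells, fill_value=1.0):
    """Return a COO neighborhood matrix with a diagonal entry for every cell.

    indices: list of (row, col) pairs; values: matching list of floats.
    Existing diagonal entries are kept unchanged; missing ones get fill_value.
    """
    has_loop = {row for row, col in indices if row == col}
    new_indices = list(indices)
    new_values = list(values)
    for cell in range(num_cells):
        if cell not in has_loop:
            new_indices.append((cell, cell))
            new_values.append(fill_value)
    return new_indices, new_values


# Neighborhood of three cells in which only cells 0 and 1 are adjacent.
indices = [(0, 1), (1, 0)]
values = [1.0, 1.0]
loop_indices, loop_values = add_self_loops(indices, values, num_cells=3)
print(loop_indices)  # [(0, 1), (1, 0), (0, 0), (1, 1), (2, 2)]
```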

forward(x, down_laplacian_1, up_laplacian_1) → Tensor

Forward pass.

Parameters:
x : torch.Tensor, shape = (n_k_cells, channels)

Input features on the r-cells of the cell complex.

down_laplacian_1 : torch.sparse, shape = (n_k_cells, n_k_cells)

Lower neighborhood matrix mapping r-cells to r-cells (A_k_low).

up_laplacian_1 : torch.sparse, shape = (n_k_cells, n_k_cells)

Upper neighborhood matrix mapping r-cells to r-cells (A_k_up).

Returns:
torch.Tensor, shape = (n_k_cells, out_channels)

Output features on the r-cells of the cell complex.

Notes

\[\mathcal N = \{\mathcal N_1, \mathcal N_2\} = \{A_{\uparrow, r}, A_{\downarrow, r}\}\]
\[\begin{split}\begin{align*} &🟥 \quad m_{(y \rightarrow x),k}^{(r)} = \alpha_k(h_x^t,h_y^t) = a_k(h_x^{t}, h_y^{t}) \cdot \psi_k^t(h_y^{t})\quad \forall \mathcal N_k\\ &🟧 \quad m_{x,k}^{(r)} = \bigoplus_{y \in \mathcal{N}_k(x)} m^{(r)} _{(y \rightarrow x),k}\\ &🟩 \quad m_{x}^{(r)} = \bigotimes_{\mathcal{N}_k\in\mathcal N}m_{x,k}^{(r)}\\ &🟦 \quad h_x^{t+1,(r)} = \phi^{t}(h_x^t, m_{x}^{(r)}) \end{align*}\end{split}\]
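The four steps of this scheme can be traced in a single-head, pure-Python sketch: attend covers 🟥 and 🟧, and can_step covers 🟩 and 🟦. It assumes GAT-style scores (LeakyReLU of a linear function of the transformed features, normalized with a softmax over each neighborhood), neighborhoods given as adjacency lists, "mean" meaning the average of the two neighborhood messages, a ReLU update, and no skip connection. All names are hypothetical; the code illustrates the computation and does not mirror TopoModelX internals.

```python
import math


def leaky_relu(z, slope=0.2):
    """LeakyReLU applied to raw attention scores."""
    return z if z > 0 else slope * z


def matvec(weight, vec):
    """Apply a weight matrix (list of rows) to a feature vector."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def attend(h, neighbors, weight, att_src, att_tgt):
    """Steps 1 and 2 for one neighborhood N_k.

    Each message is a_k(h_x, h_y) * psi_k(h_y) with psi_k(h) = W_k h, and the
    messages arriving at x are summed over y in N_k(x).
    """
    wh = [matvec(weight, row) for row in h]
    out = []
    for x, nbrs in enumerate(neighbors):
        if not nbrs:  # empty neighborhood: zero message
            out.append([0.0] * len(weight))
            continue
        tgt = sum(a * v for a, v in zip(att_tgt, wh[x]))
        scores = [leaky_relu(sum(a * v for a, v in zip(att_src, wh[y])) + tgt) for y in nbrs]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]  # stable softmax over N_k(x)
        total = sum(exps)
        alphas = [e / total for e in exps]
        out.append([sum(a * wh[y][i] for a, y in zip(alphas, nbrs)) for i in range(len(weight))])
    return out


def can_step(h, lower, upper, params_lower, params_upper, aggr="sum"):
    """Steps 3 and 4: combine both neighborhoods, then apply a ReLU update."""
    m_low = attend(h, lower, *params_lower)
    m_up = attend(h, upper, *params_upper)
    combined = [[a + b for a, b in zip(r_low, r_up)] for r_low, r_up in zip(m_low, m_up)]
    if aggr == "mean":
        combined = [[v / 2 for v in row] for row in combined]
    return [[max(v, 0.0) for v in row] for row in combined]


# Three cells with 1-D features; all-zero attention vectors give uniform weights.
h = [[1.0], [2.0], [-3.0]]
lower = [[1], [0, 2], [1]]  # adjacency lists for the lower neighborhood
upper = [[], [], []]        # empty upper neighborhood
params = ([[1.0]], [0.0], [0.0])  # (W_k, att_src, att_tgt)
print(can_step(h, lower, upper, params, params))  # [[2.0], [0.0], [2.0]]
```

With uniform attention, cell 1 receives the average of cells 0 and 2, which is negative and is clipped to zero by the ReLU update.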
reset_parameters() → None

Reset the parameters of the layer.

class topomodelx.nn.cell.can_layer.LiftLayer(in_channels_0: int, heads: int, signal_lift_activation: Callable, signal_lift_dropout: float)

Attentional Lift Layer.

This is adapted from the official implementation of the Cell Attention Network (CAN) [1].

Parameters:
in_channels_0 : int

Number of input channels of the node signal.

heads : int

Number of attention heads.

signal_lift_activation : Callable

Activation function applied to the lifted signal.

signal_lift_dropout : float

Dropout rate applied to the lifted signal.

Methods

add_module(name, module)

Adds a child module to the current module.

aggregate(x_message)

Aggregate messages on each target cell.

apply(fn)

Applies fn recursively to every submodule (as returned by .children()) as well as self.

attention(x_source[, x_target])

Compute attention weights for messages.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Returns an iterator over module buffers.

children()

Returns an iterator over immediate children modules.

cpu()

Moves all model parameters and buffers to the CPU.

cuda([device])

Moves all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Sets the module in evaluation mode.

extra_repr()

Set the extra representation of the module

float()

Casts all floating point parameters and buffers to float datatype.

forward(x_0, adjacency_0)

Forward pass.

get_buffer(target)

Returns the buffer given by target if it exists, otherwise throws an error.

get_extra_state()

Returns any extra state to include in the module's state_dict.

get_parameter(target)

Returns the parameter given by target if it exists, otherwise throws an error.

get_submodule(target)

Returns the submodule given by target if it exists, otherwise throws an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Moves all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict])

Copies parameters and buffers from state_dict into this module and its descendants.

message(x_source[, x_target])

Construct a message from source 0-cells to target 1-cell.

modules()

Returns an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Returns an iterator over module parameters.

register_backward_hook(hook)

Registers a backward hook on the module.

register_buffer(name, tensor[, persistent])

Adds a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Registers a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Registers a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Registers a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Registers a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Registers a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Adds a parameter to the module.

register_state_dict_pre_hook(hook)

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

reset_parameters()

Reinitialize learnable parameters using Xavier uniform initialization.

set_extra_state(state)

This function is called from load_state_dict() to handle any extra state found within the state_dict.

share_memory()

See torch.Tensor.share_memory_()

state_dict(*args[, destination, prefix, ...])

Returns a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

to_empty(*, device)

Moves the parameters and buffers to the specified device without copying storage.

train([mode])

Sets the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Moves all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Sets gradients of all model parameters to zero.

__call__

References

[1]

Giusti, Battiloro, Testa, Di Lorenzo, Sardellitti and Barbarossa. Cell attention networks (2022). Paper: https://arxiv.org/pdf/2209.08179.pdf Repository: lrnzgiusti/can

forward(x_0, adjacency_0) → Tensor

Forward pass.

Parameters:
x_0 : torch.Tensor, shape = (num_nodes, in_channels_0)

Node signal.

adjacency_0 : torch.Tensor, shape = (num_nodes, num_nodes)

Sparse neighborhood matrix.

Returns:
torch.Tensor, shape = (num_edges, 1)

Edge signal.

message(x_source, x_target=None)

Construct a message from source 0-cells to target 1-cell.

Parameters:
x_source : torch.Tensor, shape = (num_edges, in_channels_0)

Node signal of the source 0-cells.

x_target : torch.Tensor, shape = (num_edges, in_channels_0)

Node signal of the target 1-cell.

Returns:
torch.Tensor, shape = (num_edges, heads)

Edge signal.
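A single-head version of this lift can be sketched in pure Python. It assumes the message is one attention vector applied to the concatenated endpoint features of each edge, followed by the activation (ReLU here), in the spirit of the Θ(h_z || h_y) form given in the MultiHeadLiftLayer notes. Dropout is omitted, and lift_edges and att are hypothetical names.

```python
def lift_edges(x_0, edges, att, activation=lambda z: max(z, 0.0)):
    """Compute one scalar edge signal per edge from its endpoint node features.

    x_0: list of node feature lists (in_channels_0 values each).
    edges: list of (source, target) node index pairs.
    att: hypothetical attention vector of length 2 * in_channels_0.
    Returns one single-element row per edge, i.e. shape (num_edges, 1).
    """
    out = []
    for s, t in edges:
        concat = x_0[s] + x_0[t]  # [x_source || x_target]
        out.append([activation(sum(a * v for a, v in zip(att, concat)))])
    return out


x_0 = [[1.0], [2.0], [3.0]]
edges = [(0, 1), (1, 2), (2, 0)]
print(lift_edges(x_0, edges, att=[2.0, -1.0]))  # [[0.0], [1.0], [5.0]]
```

Because the source and target halves of att differ, the lifted value depends on edge orientation: (2, 0) gives 5.0, whereas (0, 2) would give 0.0 after the ReLU.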

reset_parameters() → None

Reinitialize learnable parameters using Xavier uniform initialization.

class topomodelx.nn.cell.can_layer.MultiHeadCellAttention(in_channels: int, out_channels: int, dropout: float, heads: int, concat: bool, att_activation: Module, add_self_loops: bool = False, aggr_func: Literal['sum', 'mean', 'add'] = 'sum', initialization: Literal['xavier_uniform', 'xavier_normal'] = 'xavier_uniform')

Attentional Message Passing v1.

Attentional message passing from the Cell Attention Network (CAN) [1], following the attention mechanism proposed in GAT [2].

Parameters:
in_channels : int

Number of input channels.

out_channels : int

Number of output channels.

dropout : float

Dropout rate applied to the output signal.

heads : int

Number of attention heads.

concat : bool

Whether to concatenate the output of each attention head.

att_activation : Callable

Activation function to use for the attention weights.

add_self_loops : bool, default=False

Whether to add self-loops to the adjacency matrix.

aggr_func : Literal["sum", "mean", "add"], default="sum"

Aggregation function to use.

initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"

Initialization method for the weights of the layer.
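The two initialization options correspond to the standard Glorot (Xavier) schemes, which PyTorch exposes as torch.nn.init.xavier_uniform_ and torch.nn.init.xavier_normal_. The stdlib-only sketch below shows the two sampling rules for illustration; the function names and the fan sizes in the example are arbitrary choices, not the layer's actual parameter shapes.

```python
import math
import random


def xavier_uniform(fan_in, fan_out, gain=1.0, rng=random):
    """Sample a fan_out x fan_in matrix from U(-a, a), a = gain * sqrt(6 / (fan_in + fan_out))."""
    bound = gain * math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-bound, bound) for _ in range(fan_in)] for _ in range(fan_out)]


def xavier_normal(fan_in, fan_out, gain=1.0, rng=random):
    """Sample a fan_out x fan_in matrix from N(0, std^2), std = gain * sqrt(2 / (fan_in + fan_out))."""
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]


# With fan_in = 16 and fan_out = 8 the uniform bound is sqrt(6 / 24) = 0.5.
weights = xavier_uniform(16, 8, rng=random.Random(0))
print(len(weights), len(weights[0]))  # 8 16
```

Both rules scale the weights by the sum of fan-in and fan-out, so the variance of activations and gradients stays roughly constant across layers; they differ only in the sampling distribution.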

Methods

add_module(name, module)

Adds a child module to the current module.

aggregate(x_message)

Aggregate messages on each target cell.

apply(fn)

Applies fn recursively to every submodule (as returned by .children()) as well as self.

attention(x_source, x_target)

Compute attention weights for messages.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Returns an iterator over module buffers.

children()

Returns an iterator over immediate children modules.

cpu()

Moves all model parameters and buffers to the CPU.

cuda([device])

Moves all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Sets the module in evaluation mode.

extra_repr()

Set the extra representation of the module

float()

Casts all floating point parameters and buffers to float datatype.

forward(x_source, neighborhood)

Forward pass.

get_buffer(target)

Returns the buffer given by target if it exists, otherwise throws an error.

get_extra_state()

Returns any extra state to include in the module's state_dict.

get_parameter(target)

Returns the parameter given by target if it exists, otherwise throws an error.

get_submodule(target)

Returns the submodule given by target if it exists, otherwise throws an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Moves all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict])

Copies parameters and buffers from state_dict into this module and its descendants.

message(x_source)

Construct message from source cells to target cells.

modules()

Returns an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Returns an iterator over module parameters.

register_backward_hook(hook)

Registers a backward hook on the module.

register_buffer(name, tensor[, persistent])

Adds a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Registers a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Registers a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Registers a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Registers a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Registers a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Adds a parameter to the module.

register_state_dict_pre_hook(hook)

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

reset_parameters()

Reset the layer parameters.

set_extra_state(state)

This function is called from load_state_dict() to handle any extra state found within the state_dict.

share_memory()

See torch.Tensor.share_memory_()

state_dict(*args[, destination, prefix, ...])

Returns a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

to_empty(*, device)

Moves the parameters and buffers to the specified device without copying storage.

train([mode])

Sets the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Moves all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Sets gradients of all model parameters to zero.

__call__

Notes

If the neighborhood matrix has no non-zero entries, the neighborhood is empty and forward returns a tensor of zeros.

References

[2]

Veličković, Cucurull, Casanova, Romero, Liò and Bengio. Graph attention networks (2017). https://arxiv.org/pdf/1710.10903.pdf

attention(x_source, x_target)

Compute attention weights for messages.

Parameters:
x_source : torch.Tensor, shape = (n_k_cells, in_channels)

Source node features.

x_target : torch.Tensor, shape = (n_k_cells, in_channels)

Target node features.

Returns:
torch.Tensor, shape = (n_k_cells, heads)

Attention weights.
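A single-head sketch of GAT-style (v1) scoring for one target cell, assuming the usual GAT form in which separate attention vectors are applied to the transformed source and target features, summed, passed through LeakyReLU, and normalized with a softmax over the neighbors. The function and parameter names are hypothetical.

```python
import math


def gat_v1_scores(wh_target, wh_neighbors, att_src, att_tgt, slope=0.2):
    """Attention coefficients of one target cell over its neighbors (single head).

    wh_target: transformed target features W h_x.
    wh_neighbors: list of transformed neighbor features W h_y.
    Score: e_xy = LeakyReLU(att_src . W h_y + att_tgt . W h_x), then softmax over y.
    """
    def leaky(z):
        return z if z > 0 else slope * z

    tgt_term = sum(a * v for a, v in zip(att_tgt, wh_target))
    scores = [leaky(sum(a * v for a, v in zip(att_src, wh)) + tgt_term) for wh in wh_neighbors]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # numerically stable softmax
    total = sum(exps)
    return [e / total for e in exps]


neighbors = [[1.0], [3.0]]
print(gat_v1_scores([0.0], neighbors, [1.0], [1.0]))
print(gat_v1_scores([-10.0], neighbors, [1.0], [1.0]))
```

In both calls the second neighbor receives the larger weight. Because LeakyReLU is monotonic and the target term is shared by all neighbors, the neighbor ranking is the same for every target; this is the "static attention" limitation analyzed in the GATv2 paper [3].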

forward(x_source, neighborhood)

Forward pass.

Parameters:
x_source : torch.Tensor, shape = (n_k_cells, channels)

Input features on the r-cells of the cell complex.

neighborhood : torch.sparse, shape = (n_k_cells, n_k_cells)

Neighborhood matrix mapping r-cells to r-cells (A_k).

Returns:
torch.Tensor, shape = (n_k_cells, channels)

Output features on the r-cells of the cell complex.

message(x_source)

Construct message from source cells to target cells.

🟥 This provides a default message function to the message passing scheme.

Parameters:
x_source : torch.Tensor, shape = (n_k_cells, channels)

Input features on the r-cells of the cell complex.

Returns:
torch.Tensor, shape = (n_k_cells, heads, in_channels)

Messages on source cells.

reset_parameters() → None

Reset the layer parameters.

class topomodelx.nn.cell.can_layer.MultiHeadCellAttention_v2(in_channels: int, out_channels: int, dropout: float, heads: int, concat: bool, att_activation: Module, add_self_loops: bool = True, aggr_func: Literal['sum', 'mean', 'add'] = 'sum', initialization: Literal['xavier_uniform', 'xavier_normal'] = 'xavier_uniform', share_weights: bool = False)

Attentional Message Passing v2.

Attentional message passing from the Cell Attention Network (CAN) [1], following the attention mechanism proposed in GATv2 [3].

Parameters:
in_channels : int

Number of input channels.

out_channels : int

Number of output channels.

dropout : float

Dropout rate applied to the output signal.

heads : int

Number of attention heads.

concat : bool

Whether to concatenate the output of each attention head.

att_activation : Callable

Activation function to use for the attention weights.

add_self_loops : bool, default=True

Whether to add self-loops to the adjacency matrix.

aggr_func : Literal["sum", "mean", "add"], default="sum"

Aggregation function to use.

initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"

Initialization method for the weights of the layer.

share_weights : bool, default=False

If True, the same linear transformation is applied to the source and target features.

Methods

add_module(name, module)

Adds a child module to the current module.

aggregate(x_message)

Aggregate messages on each target cell.

apply(fn)

Applies fn recursively to every submodule (as returned by .children()) as well as self.

attention(x_source)

Compute attention weights for messages.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Returns an iterator over module buffers.

children()

Returns an iterator over immediate children modules.

cpu()

Moves all model parameters and buffers to the CPU.

cuda([device])

Moves all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Sets the module in evaluation mode.

extra_repr()

Set the extra representation of the module

float()

Casts all floating point parameters and buffers to float datatype.

forward(x_source, neighborhood)

Forward pass.

get_buffer(target)

Returns the buffer given by target if it exists, otherwise throws an error.

get_extra_state()

Returns any extra state to include in the module's state_dict.

get_parameter(target)

Returns the parameter given by target if it exists, otherwise throws an error.

get_submodule(target)

Returns the submodule given by target if it exists, otherwise throws an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Moves all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict])

Copies parameters and buffers from state_dict into this module and its descendants.

message(x_source)

Construct message from source cells to target cells.

modules()

Returns an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Returns an iterator over module parameters.

register_backward_hook(hook)

Registers a backward hook on the module.

register_buffer(name, tensor[, persistent])

Adds a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Registers a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Registers a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Registers a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Registers a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Registers a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Adds a parameter to the module.

register_state_dict_pre_hook(hook)

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

reset_parameters()

Reset the layer parameters.

set_extra_state(state)

This function is called from load_state_dict() to handle any extra state found within the state_dict.

share_memory()

See torch.Tensor.share_memory_()

state_dict(*args[, destination, prefix, ...])

Returns a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

to_empty(*, device)

Moves the parameters and buffers to the specified device without copying storage.

train([mode])

Sets the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Moves all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Sets gradients of all model parameters to zero.

__call__

Notes

If the neighborhood matrix has no non-zero entries, the neighborhood is empty.

References

[3]

Brody, Alon, Yahav. How attentive are graph attention networks? (2022). https://arxiv.org/pdf/2105.14491.pdf

attention(x_source)

Compute attention weights for messages.

Parameters:
x_source : torch.Tensor, shape = (n_k_cells, heads, in_channels)

Source node features.

Returns:
torch.Tensor, shape = (n_k_cells, heads)

Attention weights.
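A single-head sketch of GATv2-style scoring, assuming the form e_xy = att · LeakyReLU(W_src h_y + W_tgt h_x) followed by a softmax over the neighbors. Passing the same matrix as w_src and w_tgt mimics share_weights=True. The function and parameter names, and the weights in the example, are hypothetical.

```python
import math


def gat_v2_scores(h_target, h_neighbors, w_src, w_tgt, att, slope=0.2):
    """GATv2-style attention of one target cell over its neighbors (single head).

    Score: e_xy = att . LeakyReLU(W_src h_y + W_tgt h_x), then softmax over y.
    """
    def matvec(w, v):
        return [sum(a * b for a, b in zip(row, v)) for row in w]

    def leaky(z):
        return z if z > 0 else slope * z

    tgt = matvec(w_tgt, h_target)
    scores = []
    for h_y in h_neighbors:
        # The nonlinearity is applied before the dot product with att.
        hidden = [leaky(s + t) for s, t in zip(matvec(w_src, h_y), tgt)]
        scores.append(sum(a * v for a, v in zip(att, hidden)))
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # numerically stable softmax
    total = sum(exps)
    return [e / total for e in exps]


# These weights score each neighbor by -0.8 * |h_y - h_x|, favoring the closest one.
w_src, w_tgt, att = [[1.0], [-1.0]], [[-1.0], [1.0]], [-1.0, -1.0]
neighbors = [[1.0], [3.0]]
print(gat_v2_scores([1.0], neighbors, w_src, w_tgt, att))  # first neighbor wins
print(gat_v2_scores([3.0], neighbors, w_src, w_tgt, att))  # second neighbor wins
```

Because the nonlinearity sits between the two linear maps, the neighbor ranking can change with the target, which the v1 score cannot do.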

forward(x_source, neighborhood)

Forward pass.

Parameters:
x_source : torch.Tensor, shape = (n_k_cells, channels)

Input features on the r-cells of the cell complex.

neighborhood : torch.sparse, shape = (n_k_cells, n_k_cells)

Neighborhood matrix mapping r-cells to r-cells (A_k), i.e., either the upper or the lower neighborhood.

Returns:
torch.Tensor, shape = (n_k_cells, channels)

Output features on the r-cells of the cell complex.

message(x_source)

Construct message from source cells to target cells.

🟥 This provides a default message function to the message passing scheme.

Parameters:
x_source : torch.Tensor, shape = (n_k_cells, channels)

Input features on the r-cells of the cell complex.

Returns:
torch.Tensor, shape = (n_k_cells, heads, in_channels)

Messages on source cells.

reset_parameters() → None

Reset the layer parameters.

class topomodelx.nn.cell.can_layer.MultiHeadLiftLayer(in_channels_0: int, heads: int = 1, signal_lift_activation: Callable = torch.relu, signal_lift_dropout: float = 0.0, signal_lift_readout: str = 'cat')

Multi Head Attentional Lift Layer.

Multi-head attentional lift layer adapted from the official implementation of the Cell Attention Network (CAN) [1].

Parameters:
in_channels_0 : int

Number of input channels.

heads : int, default=1

Number of attention heads.

signal_lift_activation : Callable, default=torch.relu

Activation function applied to the output edge signal.

signal_lift_dropout : float, default=0.0

Dropout rate applied to the output edge signal.

signal_lift_readout : str, default="cat"

Readout method applied to the output edge signal.

Methods

add_module(name, module)

Adds a child module to the current module.

apply(fn)

Applies fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Returns an iterator over module buffers.

children()

Returns an iterator over immediate children modules.

cpu()

Moves all model parameters and buffers to the CPU.

cuda([device])

Moves all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Sets the module in evaluation mode.

extra_repr()

Set the extra representation of the module

float()

Casts all floating point parameters and buffers to float datatype.

forward(x_0, adjacency_0[, x_1])

Forward pass.

get_buffer(target)

Returns the buffer given by target if it exists, otherwise throws an error.

get_extra_state()

Returns any extra state to include in the module's state_dict.

get_parameter(target)

Returns the parameter given by target if it exists, otherwise throws an error.

get_submodule(target)

Returns the submodule given by target if it exists, otherwise throws an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Moves all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict])

Copies parameters and buffers from state_dict into this module and its descendants.

modules()

Returns an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Returns an iterator over module parameters.

register_backward_hook(hook)

Registers a backward hook on the module.

register_buffer(name, tensor[, persistent])

Adds a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Registers a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Registers a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Registers a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Registers a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Registers a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Adds a parameter to the module.

register_state_dict_pre_hook(hook)

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

reset_parameters()

Reinitialize learnable parameters using Xavier uniform initialization.

set_extra_state(state)

This function is called from load_state_dict() to handle any extra state found within the state_dict.

share_memory()

See torch.Tensor.share_memory_()

state_dict(*args[, destination, prefix, ...])

Returns a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

to_empty(*, device)

Moves the parameters and buffers to the specified device without copying storage.

train([mode])

Sets the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Moves all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Sets gradients of all model parameters to zero.

__call__

forward(x_0, adjacency_0, x_1=None) → Tensor

Forward pass.

Parameters:
x_0 : torch.Tensor, shape = (num_nodes, in_channels_0)

Node signal.

adjacency_0 : torch.Tensor, shape = (2, num_edges)

Edge index.

x_1 : torch.Tensor, shape = (num_edges, in_channels_1), optional

Edge signal.

Returns:
torch.Tensor, shape = (num_edges, heads + in_channels_1)

Lifted node signal.

Notes

\[\begin{split}\begin{align*} &🟥 \quad m_{(y,z) \rightarrow x}^{(0 \rightarrow 1)} = \alpha(h_y, h_z) = \Theta(h_z||h_y)\\ &🟦 \quad h_x^{(1)} = \phi(h_x, m_x^{(1)}) \end{align*}\end{split}\]
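The "cat" readout and the (num_edges, heads + in_channels_1) output shape can be illustrated by stacking per-head edge signals. This sketch assumes each head yields one scalar per edge and that the head outputs come before the existing edge features x_1; that ordering and the name cat_readout are assumptions, not taken from TopoModelX.

```python
def cat_readout(head_signals, x_1=None):
    """Build the lifted edge signal from per-head scalar edge signals.

    head_signals: list with one entry per head; each entry is a list holding
    one scalar per edge (the output of a single-head lift).
    x_1: optional list of existing edge feature rows (in_channels_1 values each).
    Returns one row per edge with heads (+ in_channels_1) values.
    """
    num_edges = len(head_signals[0])
    rows = [[signal[e] for signal in head_signals] for e in range(num_edges)]
    if x_1 is not None:
        # Assumed order: head outputs first, then the existing edge features.
        rows = [row + list(feats) for row, feats in zip(rows, x_1)]
    return rows


per_head = [[0.1, 0.2, 0.3], [1.0, 2.0, 3.0]]  # 2 heads, 3 edges
print(cat_readout(per_head))  # [[0.1, 1.0], [0.2, 2.0], [0.3, 3.0]]
print(cat_readout(per_head, x_1=[[9.0], [8.0], [7.0]]))
```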
reset_parameters() → None

Reinitialize learnable parameters using Xavier uniform initialization.
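For reference, the sketch below shows what Xavier uniform initialization does to a single hypothetical (in_channels, 1) attention parameter, using torch.nn.init.xavier_uniform_. The parameter name att, its shape, and gain=1.0 are assumptions; the docstring does not state which parameters or which gain reset_parameters uses.

```python
import math

import torch
from torch import nn

# Hypothetical (in_channels, 1) attention vector, as used in x @ att.
att = nn.Parameter(torch.empty(8, 1))

gain = 1.0
nn.init.xavier_uniform_(att, gain=gain)

# Xavier uniform draws from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)).
# PyTorch reads fan_in from dim 1 and fan_out from dim 0; only their sum matters here.
fan_in, fan_out = att.shape[1], att.shape[0]
bound = gain * math.sqrt(6.0 / (fan_in + fan_out))
print(bool(att.abs().max() <= bound))  # True
```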

class topomodelx.nn.cell.can_layer.PoolLayer(k_pool: float, in_channels_0: int, signal_pool_activation: Callable, readout: bool = True)

Attentional Pooling Layer.

Attentional Pooling Layer adapted from the official implementation of the Cell Attention Network (CAN) [1].

Parameters:
k_pool : float in (0, 1]

The pooling ratio, i.e., the fraction of r-cells to keep after the pooling operation.

in_channels_0 : int

Number of input channels of the input signal.

signal_pool_activation : Callable

Activation function applied to the pooled signal.

readout : bool, default=True

Whether to apply a readout operation to the pooled signal.

Methods

add_module(name, module)

Adds a child module to the current module.

aggregate(x_message)

Aggregate messages on each target cell.

apply(fn)

Applies fn recursively to every submodule (as returned by .children()) as well as self.

attention(x_source[, x_target])

Compute attention weights for messages.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Returns an iterator over module buffers.

children()

Returns an iterator over immediate children modules.

cpu()

Moves all model parameters and buffers to the CPU.

cuda([device])

Moves all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Sets the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x, down_laplacian_1, up_laplacian_1)

Forward pass.

get_buffer(target)

Returns the buffer given by target if it exists, otherwise throws an error.

get_extra_state()

Returns any extra state to include in the module's state_dict.

get_parameter(target)

Returns the parameter given by target if it exists, otherwise throws an error.

get_submodule(target)

Returns the submodule given by target if it exists, otherwise throws an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Moves all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict])

Copies parameters and buffers from state_dict into this module and its descendants.

message(x_source[, x_target])

Construct message from source cells to target cells.

modules()

Returns an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Returns an iterator over module parameters.

register_backward_hook(hook)

Registers a backward hook on the module.

register_buffer(name, tensor[, persistent])

Adds a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Registers a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Registers a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Registers a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Registers a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Registers a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Adds a parameter to the module.

register_state_dict_pre_hook(hook)

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

reset_parameters()

Reinitialize learnable parameters using Xavier uniform initialization.

set_extra_state(state)

This function is called from load_state_dict() to handle any extra state found within the state_dict.

share_memory()

See torch.Tensor.share_memory_()

state_dict(*args[, destination, prefix, ...])

Returns a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

to_empty(*, device)

Moves the parameters and buffers to the specified device without copying storage.

train([mode])

Sets the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Moves all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Sets gradients of all model parameters to zero.

__call__

forward(x, down_laplacian_1, up_laplacian_1) → tuple[Tensor, Tensor, Tensor]

Forward pass.

Parameters:
x : torch.Tensor, shape = (n_r_cells, in_channels_r)

Input r-cell signal.

down_laplacian_1 : torch.Tensor

Lower neighborhood matrix.

up_laplacian_1 : torch.Tensor

Upper neighborhood matrix.

Returns:
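tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Pooled r-cell signal, restricted to the fraction k_pool of r-cells retained by the pooling and still with in_channels_r columns, followed by the lower and upper neighborhood matrices restricted to the same retained r-cells.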

Notes

\[
\begin{aligned}
&🟥 \quad m_{x}^{(r)} = \gamma^t(h_x^t) = \tau^t (a^t\cdot h_x^t)\\
&🟦 \quad h_x^{t+1,(r)} = \phi^t(h_x^t, m_{x}^{(r)}), \quad \forall x\in \mathcal C_r^{t+1}
\end{aligned}
\]

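The update above can be sketched with dense tensors as follows: each r-cell receives a score a · h_x, the ceil(k_pool * n_r_cells) highest-scoring cells are kept, their features are gated by τ applied to the score, and both neighborhood matrices are restricted to the kept cells. The helper name attentional_pool, the sigmoid used for τ, the ceil rounding, and the use of dense neighborhood matrices are all assumptions for illustration, not the library's code.

```python
import math

import torch


def attentional_pool(x, down, up, att, k_pool, activation=torch.sigmoid):
    """Illustrative top-k attentional pooling of r-cells (dense matrices).

    x:    (n_r_cells, channels) r-cell signal h_x
    down: (n_r_cells, n_r_cells) lower neighborhood matrix
    up:   (n_r_cells, n_r_cells) upper neighborhood matrix
    att:  (channels, 1) attention vector a
    """
    # Message: one score a . h_x per r-cell.
    score = (x @ att).view(-1)
    # Keep the ceil(k_pool * n_r_cells) highest-scoring cells.
    k = math.ceil(k_pool * x.shape[0])
    idx = torch.topk(score, k).indices
    # Update: gate each kept cell's features by tau(score).
    x_pooled = x[idx] * activation(score[idx]).view(-1, 1)
    # Restrict both neighborhoods to the kept cells (rows, then columns).
    down_pooled = down[idx][:, idx]
    up_pooled = up[idx][:, idx]
    return x_pooled, down_pooled, up_pooled


x = torch.tensor([[1.0], [4.0], [2.0], [3.0]])
down = torch.arange(16.0).view(4, 4)
up = down.t()
x_p, down_p, up_p = attentional_pool(x, down, up, torch.ones(1, 1), k_pool=0.5)
print(x_p.shape, down_p.shape)  # torch.Size([2, 1]) torch.Size([2, 2])
```

Returning the restricted neighborhoods alongside the signal lets a following layer operate directly on the smaller complex.
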
reset_parameters() → None

Reinitialize learnable parameters using Xavier uniform initialization.

topomodelx.nn.cell.can_layer.add_self_loops(neighborhood)

Add self-loops to the neighborhood matrix.

Parameters:
neighborhood : torch.sparse_coo_tensor, shape = (n_k_cells, n_k_cells)

Neighborhood matrix.

Returns:
torch.sparse_coo_tensor, shape = (n_k_cells, n_k_cells)

Neighborhood matrix with self-loops.

Notes

This function should eventually be moved to the utils file.
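A minimal sketch of the operation, assuming a square sparse COO input: it appends one (i, i) entry with value 1 per cell and coalesces, so the result equals the neighborhood matrix plus the identity. Because coalesce() sums duplicate entries, an existing diagonal value is incremented by 1 rather than replaced; whether the library behaves the same way is not stated here. The name add_self_loops_sketch is hypothetical.

```python
import torch


def add_self_loops_sketch(neighborhood: torch.Tensor) -> torch.Tensor:
    """Return neighborhood + I for a square sparse COO matrix (illustrative)."""
    neighborhood = neighborhood.coalesce()  # indices() requires a coalesced tensor
    n = neighborhood.shape[0]
    # Diagonal positions (i, i) for every cell: shape (2, n).
    loop_index = torch.arange(n, device=neighborhood.device).repeat(2, 1)
    loop_value = torch.ones(n, dtype=neighborhood.dtype, device=neighborhood.device)
    indices = torch.cat([neighborhood.indices(), loop_index], dim=1)
    values = torch.cat([neighborhood.values(), loop_value])
    # coalesce() merges duplicate (i, i) entries by summing their values.
    return torch.sparse_coo_tensor(indices, values, neighborhood.shape).coalesce()


adj = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]).to_sparse()
print(add_self_loops_sketch(adj).to_dense())
```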

topomodelx.nn.cell.can_layer.softmax(src, index, num_cells: int)

Compute the softmax of the attention coefficients.

Parameters:
src : torch.Tensor, shape = (n_k_cells, heads)

Attention coefficients.

index : torch.Tensor, shape = (n_k_cells,)

Indices of the target nodes.

num_cells : int

Number of cells in the batch.

Returns:
torch.Tensor, shape = (n_k_cells, heads)

Softmax of the attention coefficients.

Notes

A default implementation of softmax should eventually be provided in the utils file. The maximum element is subtracted from all elements before exponentiation to avoid overflow and underflow.
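The grouped softmax described above can be sketched as follows, assuming src holds one row of per-head scores per message and index gives each message's target cell: scores are shifted by their per-head maximum, exponentiated, summed per target cell with Tensor.index_add, and divided by that sum. The name grouped_softmax and the 1e-16 stabilizer are assumptions for illustration.

```python
import torch


def grouped_softmax(src: torch.Tensor, index: torch.Tensor, num_cells: int) -> torch.Tensor:
    """Softmax of attention scores over all messages sharing a target cell.

    src:   (num_messages, heads) raw attention coefficients
    index: (num_messages,) target-cell index of each message
    """
    # Shift by the per-head maximum; this returns a new tensor, leaving src intact.
    src = src - src.max(dim=0, keepdim=True).values
    src_exp = src.exp()
    # Sum the exponentials of each target cell's messages: (num_cells, heads).
    denom = torch.zeros(num_cells, src.shape[1], dtype=src.dtype, device=src.device)
    denom = denom.index_add(0, index, src_exp)
    # Gather each message's group sum back and normalize.
    return src_exp / (denom[index] + 1e-16)


scores = torch.tensor([[1.0, 0.0], [2.0, 0.0], [3.0, 5.0]])
targets = torch.tensor([0, 0, 1])
print(grouped_softmax(scores, targets, num_cells=2))
```

Since softmax is invariant to adding the same constant to every score in a group, one per-head maximum is enough to prevent overflow; a per-group maximum (for example via Tensor.scatter_reduce with reduce="amax") would additionally keep groups far below the global maximum from underflowing to zero.
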