Readouts

Abstract base class for readout layers.

class topobenchmarkx.nn.readouts.base.AbstractZeroCellReadOut(hidden_dim: int, out_channels: int, task_level: str, pooling_type: str = 'sum', **kwargs)

Readout layer for GNNs that operates on the batch level.

Parameters:
hidden_dim : int

Hidden dimension of the GNN model.

out_channels : int

Number of output channels.

task_level : str

Task level for the readout layer. Either “graph” or “node”.

pooling_type : str

Pooling type for the readout layer. Either “max”, “sum”, or “mean”.

**kwargs : dict

Additional arguments.
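The constructor arguments can be pictured as a small, validated configuration. The sketch below is a hypothetical container, not the library's code: whether `AbstractZeroCellReadOut` actually rejects other values, and the assumption that a linear layer maps `hidden_dim` features to `out_channels` outputs, are both assumptions made for illustration. `ReadoutArgs`, `VALID_TASK_LEVELS`, and `VALID_POOLING_TYPES` are invented names, and `**kwargs` is left out.

```python
from dataclasses import dataclass

VALID_TASK_LEVELS = ("graph", "node")
VALID_POOLING_TYPES = ("max", "sum", "mean")


@dataclass
class ReadoutArgs:
    """Hypothetical container for the AbstractZeroCellReadOut arguments."""

    hidden_dim: int
    out_channels: int
    task_level: str
    pooling_type: str = "sum"  # same default as the documented signature

    def __post_init__(self) -> None:
        # Assumed validation: accept only the values listed for
        # task_level and pooling_type in the parameter descriptions.
        if self.task_level not in VALID_TASK_LEVELS:
            raise ValueError(f"task_level must be one of {VALID_TASK_LEVELS}")
        if self.pooling_type not in VALID_POOLING_TYPES:
            raise ValueError(f"pooling_type must be one of {VALID_POOLING_TYPES}")

    @property
    def output_layer_shape(self) -> tuple:
        # Assumed: a linear layer maps hidden_dim features to out_channels logits.
        return (self.hidden_dim, self.out_channels)


args = ReadoutArgs(hidden_dim=64, out_channels=10, task_level="graph")
```

With this sketch, `ReadoutArgs(hidden_dim=64, out_channels=10, task_level="edge")` raises `ValueError`, and omitting `pooling_type` yields `"sum"`.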

compute_logits(x, batch)

Compute the output logits of the readout layer from the node embeddings.

Parameters:
x : torch.Tensor

Node embeddings.

batch : torch.Tensor

Batch index tensor.

Returns:
torch.Tensor

Logits tensor.
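The arithmetic behind `compute_logits` can be sketched with plain Python lists. This is a minimal sketch under stated assumptions, not the library's implementation (which presumably uses tensor scatter operations and a learned linear layer): for `task_level == "graph"`, node rows sharing a batch index are pooled with the chosen `pooling_type` and then projected; for `task_level == "node"`, each node row is projected directly. The helpers `pool`, `linear`, and `compute_logits` here are hypothetical, and the weights are fixed numbers rather than learned parameters.

```python
def pool(x, batch, pooling_type):
    """Reduce the rows of x that share a batch index, giving one row per graph.

    Assumes batch indices run from 0 to num_graphs - 1 and every graph has
    at least one node.
    """
    num_graphs = max(batch) + 1
    groups = [[] for _ in range(num_graphs)]
    for row, graph_id in zip(x, batch):
        groups[graph_id].append(row)
    pooled = []
    for rows in groups:
        columns = list(zip(*rows))  # one tuple per feature channel
        if pooling_type == "sum":
            pooled.append([sum(c) for c in columns])
        elif pooling_type == "mean":
            pooled.append([sum(c) / len(c) for c in columns])
        elif pooling_type == "max":
            pooled.append([max(c) for c in columns])
        else:
            raise ValueError(f"unknown pooling_type: {pooling_type}")
    return pooled


def linear(x, weight, bias):
    """Compute y = W v + b for each row v; weight is (out_channels, hidden_dim)."""
    return [
        [sum(w * v for w, v in zip(w_row, row)) + b for w_row, b in zip(weight, bias)]
        for row in x
    ]


def compute_logits(x, batch, task_level, pooling_type, weight, bias):
    """Hypothetical sketch: pool per graph for graph-level tasks, then project."""
    if task_level == "graph":
        x = pool(x, batch, pooling_type)
    # For task_level == "node", each node keeps its own row of logits.
    return linear(x, weight, bias)


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # three nodes, hidden_dim = 2
batch = [0, 0, 1]  # nodes 0 and 1 belong to graph 0, node 2 to graph 1
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # out_channels = 3
bias = [0.0, 0.0, 0.5]
logits = compute_logits(x, batch, "graph", "sum", weight, bias)
# logits == [[4.0, 6.0, 10.5], [5.0, 6.0, 11.5]]
```

Switching `pooling_type` to `"mean"` pools graph 0 to `[2.0, 3.0]`, and `"max"` pools it to `[3.0, 4.0]`; graph 1 has a single node, so all three pooling types return `[5.0, 6.0]`.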

abstract forward(model_out: dict, batch: Data)

Forward pass.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.
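Because `forward` is abstract, every concrete readout has to override it. A minimal sketch of that contract using Python's `abc` module follows, assuming that calling a readout simply dispatches to `forward` (the real base class may do more around that call, for example computing logits). `ReadoutBase`, `DoubleNodeFeatures`, and the `x_0` key are hypothetical, and a plain dict stands in for the `torch_geometric.data.Data` batch.

```python
from abc import ABC, abstractmethod


class ReadoutBase(ABC):
    """Hypothetical pure-Python stand-in for an abstract readout base class."""

    @abstractmethod
    def forward(self, model_out: dict, batch: dict) -> dict:
        """Transform the model output; every concrete readout must override this."""

    def __call__(self, model_out: dict, batch: dict) -> dict:
        # Assumption: calling the readout simply dispatches to forward().
        return self.forward(model_out, batch)


class DoubleNodeFeatures(ReadoutBase):
    """Hypothetical subclass that rescales the rank-0 embeddings."""

    def forward(self, model_out: dict, batch: dict) -> dict:
        model_out["x_0"] = [[2 * v for v in row] for row in model_out["x_0"]]
        return model_out


out = DoubleNodeFeatures()({"x_0": [[1.0, -0.5]]}, batch={})
# out["x_0"] == [[2.0, -1.0]]
```

Trying to instantiate `ReadoutBase()` directly raises `TypeError`, since `forward` has no implementation.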

Readout layer that does not perform any operation on the node embeddings.

class topobenchmarkx.nn.readouts.identical.NoReadOut(**kwargs)

No readout layer.

This readout layer does not perform any operation on the node embeddings.

Parameters:
**kwargs : dict, optional

Additional keyword arguments.

forward(model_out: dict, batch: Data) → dict

Forward pass of the NoReadOut layer.

It returns the model output without any modification.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the model output.
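A pass-through readout reduces to a `forward` that returns its input. The sketch below is a hypothetical pure-Python analogue, not `NoReadOut` itself; returning the very same dictionary object (rather than a copy) is one reading of “without any modification”, and `IdentityReadout` and the `x_0` key are invented for illustration.

```python
class IdentityReadout:
    """Hypothetical pure-Python analogue of a pass-through readout."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted (and kept) so the readout can be built
        # from the same configuration as other readouts; none are required.
        self.config = kwargs

    def forward(self, model_out: dict, batch=None) -> dict:
        # No operation on the embeddings: the dictionary is returned as-is.
        return model_out


model_out = {"x_0": [[0.1, 0.2], [0.3, 0.4]]}
readout = IdentityReadout(hidden_dim=2)
result = readout.forward(model_out, batch=None)
# result is model_out: the same object, with nothing copied or changed
```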

Readout layer that propagates the signal from cells of a given order to cells of lower order.

class topobenchmarkx.nn.readouts.propagate_signal_down.PropagateSignalDown(**kwargs)

PropagateSignalDown readout layer.

This readout layer propagates the signal from cells of a given order to cells of lower order.

Parameters:
**kwargs : dict

Additional keyword arguments. They should contain the following keys:

- num_cell_dimensions (int): Highest order of cells considered by the model.
- hidden_dim (int): Dimension of the cell representations.
- readout_name (str): Readout name.
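The required keyword arguments can be checked up front. The sketch below is hypothetical: `REQUIRED_KWARGS` and `check_readout_kwargs` are invented names, the library may not perform such a check, and the values in `readout_kwargs` (including `"PropagateSignalDown"` as the `readout_name`) are illustrative assumptions.

```python
REQUIRED_KWARGS = {
    "num_cell_dimensions": int,  # highest order of cells considered by the model
    "hidden_dim": int,  # dimension of the cell representations
    "readout_name": str,  # readout name
}


def check_readout_kwargs(kwargs: dict) -> None:
    """Hypothetical check that every documented key is present with the right type."""
    for key, expected_type in REQUIRED_KWARGS.items():
        if key not in kwargs:
            raise KeyError(f"missing required keyword argument: {key}")
        if not isinstance(kwargs[key], expected_type):
            raise TypeError(f"{key} must be of type {expected_type.__name__}")


# Illustrative values; "PropagateSignalDown" as readout_name is an assumption.
readout_kwargs = {
    "num_cell_dimensions": 2,
    "hidden_dim": 32,
    "readout_name": "PropagateSignalDown",
}
check_readout_kwargs(readout_kwargs)  # passes silently
```

Dropping any key, for example passing only `{"hidden_dim": 32}`, raises `KeyError`.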

forward(model_out: dict, batch: Data)

Forward pass of the PropagateSignalDown readout layer.

The layer takes the embeddings of the cells of a given order and applies a convolutional layer to them, followed by layer normalization. The result is concatenated with the initial embeddings of the cells, and a linear layer projects the concatenation to the dimensions of the cells of lower rank. The process is repeated until the node embeddings, i.e. the embeddings of the cells of rank 0, are reached.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the updated model output.
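The rank-by-rank loop described for this forward pass can be sketched in plain Python. Several details here are assumptions rather than facts about the library: embeddings live under hypothetical keys `x_0`, `x_1` and so on, and incidence matrices under `incidence_1` and so on. The convolution is modeled as aggregation through a dense incidence matrix followed by a channel-mixing weight matrix. The “initial embeddings of the cells” are read as the current embeddings of the rank-(r-1) cells. Weights are fixed lists rather than learned parameters, and layer normalization has no learnable scale or shift. The loop runs from rank `num_cell_dimensions` down to 1, following its description as the highest cell order.

```python
import math


def matmul(a, b):
    """Dense product of a (n x k) and b (k x m), both given as nested lists."""
    return [
        [sum(a_ik * b[k][j] for k, a_ik in enumerate(row)) for j in range(len(b[0]))]
        for row in a
    ]


def layer_norm(x, eps=1e-5):
    """Normalize each row to zero mean and unit variance (no learned affine)."""
    out = []
    for row in x:
        mean = sum(row) / len(row)
        var = sum((v - mean) ** 2 for v in row) / len(row)
        out.append([(v - mean) / math.sqrt(var + eps) for v in row])
    return out


def propagate_signal_down(model_out, batch, num_cell_dimensions, conv_weights, proj_weights):
    """Hypothetical sketch: push embeddings from the highest rank down to rank 0."""
    for r in range(num_cell_dimensions, 0, -1):
        # "Convolution" (assumed form): aggregate rank-r features onto the
        # rank-(r-1) cells via the incidence matrix, then mix the channels.
        x = matmul(matmul(batch[f"incidence_{r}"], model_out[f"x_{r}"]), conv_weights[r])
        x = layer_norm(x)
        # Concatenate with the current rank-(r-1) embeddings, then project the
        # 2 * hidden_dim features back to hidden_dim with a linear map.
        stacked = [new + old for new, old in zip(x, model_out[f"x_{r - 1}"])]
        model_out[f"x_{r - 1}"] = matmul(stacked, proj_weights[r])
    return model_out


# Toy complex: three nodes (rank 0) joined by two edges (rank 1), hidden_dim = 2.
model_out = {
    "x_0": [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]],
    "x_1": [[1.0, 3.0], [2.0, 0.0]],
}
batch = {"incidence_1": [[1, 0], [1, 1], [0, 1]]}  # node-by-edge incidence
identity = [[1.0, 0.0], [0.0, 1.0]]
proj = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]  # adds the two halves
out = propagate_signal_down(model_out, batch, 1, {1: identity}, {1: proj})
```

In this toy case the middle node aggregates `[3.0, 3.0]` from its two edges; layer normalization maps that constant row to zeros, so its updated embedding equals its original `[1.0, 1.0]`. The rank-1 embeddings themselves are left unchanged.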