Source code for topobenchmarkx.nn.readouts.base

"""Abstract base class for readout layers."""

from abc import abstractmethod

import torch
import torch_geometric
from torch_geometric.utils import scatter


class AbstractZeroCellReadOut(torch.nn.Module):
    r"""Readout layer for GNNs that operates on the batch level.

    Subclasses implement :meth:`forward`, which receives the model output
    dictionary and the batch and returns the (possibly updated) dictionary.
    Calling the module runs :meth:`forward` (through ``nn.Module.__call__``,
    so registered hooks fire) and then stores
    ``compute_logits(model_out["x_0"], batch["batch_0"])`` under the
    ``"logits"`` key of the returned dictionary.

    Parameters
    ----------
    hidden_dim : int
        Hidden dimension of the GNN model (input size of the final linear
        layer).
    out_channels : int
        Number of output channels (output size of the final linear layer).
    task_level : str
        Task level for readout layer. Either "graph" or "node".
    pooling_type : str
        Pooling type for readout layer. Either "max", "sum" or "mean".
        Only used when ``task_level == "graph"``.
    **kwargs : dict
        Additional arguments (accepted and ignored by this base class).

    Raises
    ------
    ValueError
        If ``task_level`` or ``pooling_type`` is not a supported value.
    """

    # Supported option values, checked in ``__init__``.
    _TASK_LEVELS = ("graph", "node")
    _POOLING_TYPES = ("max", "sum", "mean")

    def __init__(
        self,
        hidden_dim: int,
        out_channels: int,
        task_level: str,
        pooling_type: str = "sum",
        **kwargs,
    ):
        super().__init__()
        # Explicit checks rather than ``assert`` so that validation is not
        # stripped when Python runs with optimizations (``python -O``).
        if task_level not in self._TASK_LEVELS:
            raise ValueError(
                f"Invalid task_level {task_level!r}; "
                f"expected one of {list(self._TASK_LEVELS)}"
            )
        if pooling_type not in self._POOLING_TYPES:
            raise ValueError(
                f"Invalid pooling_type {pooling_type!r}; "
                f"expected one of {list(self._POOLING_TYPES)}"
            )
        self.linear = torch.nn.Linear(hidden_dim, out_channels)
        self.task_level = task_level
        self.pooling_type = pooling_type

    def __repr__(self):
        return f"{self.__class__.__name__}(task_level={self.task_level}, pooling_type={self.pooling_type})"

    def __call__(
        self, model_out: dict, batch: "torch_geometric.data.Data"
    ) -> dict:
        """Readout logic based on model_output.

        Parameters
        ----------
        model_out : dict
            Dictionary containing the model output. After :meth:`forward`
            it must contain the key ``"x_0"``.
        batch : torch_geometric.data.Data
            Batch object containing the batched domain data; must be
            indexable with ``"batch_0"``.

        Returns
        -------
        dict
            The dictionary returned by :meth:`forward`, with an added
            ``"logits"`` entry.
        """
        # Dispatch through ``nn.Module.__call__`` instead of calling
        # ``self.forward`` directly so forward pre-/post-hooks registered on
        # this module are honoured.
        model_out = super().__call__(model_out, batch)
        model_out["logits"] = self.compute_logits(
            model_out["x_0"], batch["batch_0"]
        )
        return model_out

    def compute_logits(self, x, batch):
        r"""Compute logits based on the readout layer.

        For ``task_level == "graph"`` the rows of ``x`` are first pooled per
        index in ``batch`` using ``self.pooling_type``. The final linear
        layer is then applied.

        Parameters
        ----------
        x : torch.Tensor
            Node embeddings.
        batch : torch.Tensor
            Batch index tensor assigning each row of ``x`` to a group.

        Returns
        -------
        torch.Tensor
            Logits tensor.
        """
        if self.task_level == "graph":
            x = scatter(x, batch, dim=0, reduce=self.pooling_type)
        return self.linear(x)

    @abstractmethod
    def forward(
        self, model_out: dict, batch: "torch_geometric.data.Data"
    ) -> dict:
        r"""Forward pass.

        Parameters
        ----------
        model_out : dict
            Dictionary containing the model output.
        batch : torch_geometric.data.Data
            Batch object containing the batched domain data.

        Returns
        -------
        dict
            The updated model output; :meth:`__call__` reads ``"x_0"`` from
            it.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        # ``torch.nn.Module`` does not use ``ABCMeta``, so ``@abstractmethod``
        # is not enforced at instantiation time. Fail loudly here instead of
        # implicitly returning ``None`` (which would crash later in
        # ``__call__`` with an unhelpful ``TypeError``).
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )