Source code for topomodelx.base.aggregation

"""Aggregation module."""
from typing import Literal

import torch


[docs] class Aggregation(torch.nn.Module): """Message passing layer. Parameters ---------- aggr_func : {"mean", "sum"}, default="sum" Aggregation method (Inter-neighborhood). update_func : {"relu", "sigmoid", "tanh", None}, default="sigmoid" Update method to apply to merged message. """ def __init__( self, aggr_func: Literal["mean", "sum"] = "sum", update_func: Literal["relu", "sigmoid", "tanh"] | None = "sigmoid", ) -> None: super().__init__() self.aggr_func = aggr_func self.update_func = update_func
[docs] def update(self, inputs): """Update (Step 4). Parameters ---------- input : torch.Tensor Features for the update step. Returns ------- torch.Tensor Updated features with the same shape as input. """ if self.update_func == "sigmoid": return torch.sigmoid(inputs) if self.update_func == "relu": return torch.nn.functional.relu(inputs) if self.update_func == "tanh": return torch.tanh(inputs) return None
[docs] def forward(self, x): """Forward pass. Parameters ---------- x : list A list of messages to merge. Each message has shape [n_skeleton_in, channels] and len = n_messages_to_merge. Returns ------- torch.Tensor Aggregated messages. """ if self.aggr_func == "sum": x = torch.sum(torch.stack(x), axis=0) if self.aggr_func == "mean": x = torch.mean(torch.stack(x), axis=0) if self.update_func is not None: x = self.update(x) return x